mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-03 23:36:43 +08:00
Compare commits
25 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
1adb02e087 | ||
![]() |
4d4140aa99 | ||
![]() |
e31840a37d | ||
![]() |
7d2e16f2d3 | ||
![]() |
92cd935a16 | ||
![]() |
25ce27ce2d | ||
![]() |
527d084a4b | ||
![]() |
bb9f937bb0 | ||
![]() |
511fe88684 | ||
![]() |
75504ff201 | ||
![]() |
a556feb325 | ||
![]() |
d06f47f4b9 | ||
![]() |
8d4cc091b4 | ||
![]() |
d8f28cb843 | ||
![]() |
88861c219d | ||
![]() |
7ba6cf28d9 | ||
![]() |
c174cfdc6b | ||
![]() |
4f198a99dd | ||
![]() |
2a9c9fcc40 | ||
![]() |
835a85c8bf | ||
![]() |
fe5d9ffa61 | ||
![]() |
aac186dcc1 | ||
![]() |
42931f332f | ||
![]() |
8a04648c09 | ||
![]() |
854c033fb6 |
28
README.md
28
README.md
@@ -3,8 +3,8 @@
|
||||
|
||||

|
||||
[](https://coveralls.io/github/mochi-co/mqtt?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-co/mqtt)
|
||||
[](https://pkg.go.dev/github.com/mochi-co/mqtt)
|
||||
[](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
|
||||
[](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
|
||||
[](https://github.com/mochi-co/mqtt/issues)
|
||||
|
||||
</p>
|
||||
@@ -19,6 +19,11 @@ MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It
|
||||
## What's new in Version 2.0.0?
|
||||
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
|
||||
|
||||
Don't forget to use the new v2 import paths:
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/v2"
|
||||
```
|
||||
|
||||
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
|
||||
- User and MQTTv5 Packet Properties
|
||||
- Topic Aliases
|
||||
@@ -108,6 +113,7 @@ Examples of running the broker with various configurations can be found in the [
|
||||
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
|
||||
|
||||
- `listeners.NewTCP(...)` - A TCP listener.
|
||||
- `listeners.NewUnixSock(...)` - A Unix Socket listener.
|
||||
- `listeners.NewWebsocket(...)` A Websocket listener.
|
||||
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
|
||||
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
|
||||
@@ -291,7 +297,6 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
|
||||
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
|
||||
| OnClientExpired | Called when a client session has expired and should be deleted. |
|
||||
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
|
||||
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
|
||||
| StoredClients | Returns clients, eg. from a persistent store. |
|
||||
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
|
||||
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
|
||||
@@ -300,13 +305,22 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
|
||||
|
||||
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
|
||||
|
||||
### Packet Injection
|
||||
It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
|
||||
|
||||
Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement).
|
||||
### Direct Publish
|
||||
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
|
||||
|
||||
```go
|
||||
cl := mqtt.NewInlineClient("inline", "local")
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
```
|
||||
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
|
||||
|
||||
### Packet Injection
|
||||
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
|
||||
|
||||
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
|
||||
|
||||
```go
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
|
69
clients.go
69
clients.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -145,14 +146,10 @@ type ClientState struct {
|
||||
keepalive uint16 // the number of seconds the connection can wait
|
||||
}
|
||||
|
||||
// NewClient returns a new instance of Client.
|
||||
func NewClient(c net.Conn, o *ops) *Client {
|
||||
// newClient returns a new instance of Client. This is almost exclusively used by Server
|
||||
// for creating new clients, but it lives here because it's not dependent.
|
||||
func newClient(c net.Conn, o *ops) *Client {
|
||||
cl := &Client{
|
||||
Net: ClientConnection{
|
||||
conn: c,
|
||||
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
},
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
@@ -165,46 +162,19 @@ func NewClient(c net.Conn, o *ops) *Client {
|
||||
ops: o,
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
cl.Net = ClientConnection{
|
||||
conn: c,
|
||||
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// NewInlineClient returns a client used when publishing from the embedding system.
|
||||
func NewInlineClient(id, remote string) *Client {
|
||||
return &Client{
|
||||
ID: id,
|
||||
Net: ClientConnection{
|
||||
Remote: remote,
|
||||
Inline: true,
|
||||
},
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(0),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newClientStub returns an instance of Client with minimal initializations, such as
|
||||
// restoring client data from a db. In particular, the client is marked as offline (done).
|
||||
func newClientStub() *Client {
|
||||
return &Client{
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(0),
|
||||
done: 1,
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ParseConnect parses the connect parameters and properties for a client.
|
||||
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Net.Listener = lid
|
||||
@@ -320,17 +290,18 @@ func (cl *Client) ResendInflightMessages(force bool) error {
|
||||
}
|
||||
|
||||
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
|
||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
|
||||
var deleted int64
|
||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
|
||||
deleted := []uint16{}
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
deleted++
|
||||
deleted = append(deleted, uint16(tk.PacketID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return deleted
|
||||
}
|
||||
|
||||
@@ -413,6 +384,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
|
||||
return nil
|
||||
}
|
||||
@@ -502,8 +477,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
|
||||
}
|
||||
|
||||
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
|
||||
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
|
||||
}
|
||||
|
||||
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
|
||||
|
121
clients_test.go
121
clients_test.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -21,10 +22,10 @@ const pkInfo = "packet type %v, %s"
|
||||
|
||||
var errClientStop = errors.New("test stop")
|
||||
|
||||
func newClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
r, w = net.Pipe()
|
||||
|
||||
cl = NewClient(w, &ops{
|
||||
cl = newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: &logger,
|
||||
@@ -118,34 +119,21 @@ func TestClientsGetByListener(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Nil(t, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestNewClientStub(t *testing.T) {
|
||||
cl := newClientStub()
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
|
||||
}
|
||||
|
||||
func TestNewInlineClient(t *testing.T) {
|
||||
cl := NewInlineClient("inline", "local")
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done))
|
||||
require.Equal(t, "inline", cl.ID)
|
||||
require.Equal(t, "local", cl.Net.Remote)
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.False(t, cl.Net.Inline)
|
||||
}
|
||||
|
||||
func TestClientParseConnect(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
@@ -182,7 +170,7 @@ func TestClientParseConnect(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
@@ -207,13 +195,13 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientParseConnectNoID(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ParseConnect("tcp1", packets.Packet{})
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientNextPacketID(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
@@ -225,7 +213,7 @@ func TestClientNextPacketID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
// skip over 2
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
@@ -248,7 +236,7 @@ func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDExhausted(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
for i := 0; i <= 65535; i++ {
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
|
||||
}
|
||||
@@ -260,7 +248,7 @@ func TestClientNextPacketIDExhausted(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
cl.State.packetID = uint32(65534)
|
||||
|
||||
@@ -274,7 +262,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientClearInflights(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
|
||||
@@ -284,13 +272,15 @@ func TestClientClearInflights(t *testing.T) {
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
|
||||
cl.ClearInflights(n, 4)
|
||||
deleted := cl.ClearInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
|
||||
require.Equal(t, 2, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessages(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
@@ -310,7 +300,7 @@ func TestClientResendInflightMessages(t *testing.T) {
|
||||
|
||||
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
r.Close()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
@@ -322,19 +312,19 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientRefreshDeadline(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.refreshDeadline(10)
|
||||
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeader(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
@@ -349,7 +339,7 @@ func TestClientReadFixedHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
@@ -362,8 +352,24 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
cl.ops.capabilities.MaximumPacketSize = 2
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
@@ -377,7 +383,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
@@ -391,7 +397,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadOK(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -445,7 +451,7 @@ func TestClientReadOK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadDone(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
cl.State.done = 1
|
||||
|
||||
@@ -460,15 +466,16 @@ func TestClientReadDone(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientStop(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(nil)
|
||||
require.Equal(t, nil, cl.State.stopCause.Load())
|
||||
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
|
||||
require.Equal(t, uint32(1), cl.State.done)
|
||||
require.Equal(t, nil, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -485,7 +492,7 @@ func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -505,7 +512,7 @@ func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadReadPacketOK(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -537,7 +544,7 @@ func TestClientReadReadPacketOK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacket(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
for _, tx := range pkTable {
|
||||
@@ -570,9 +577,17 @@ func TestClientReadPacket(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadPacketInvalidTypeError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.conn.Close()
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid packet type")
|
||||
}
|
||||
|
||||
func TestClientWritePacket(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
|
||||
@@ -612,7 +627,7 @@ func TestClientWritePacket(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWriteClientOversizePacket(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.MaximumPacketSize = 2
|
||||
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
|
||||
err := cl.WritePacket(pk)
|
||||
@@ -621,7 +636,7 @@ func TestWriteClientOversizePacket(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadingError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -641,7 +656,7 @@ func TestClientReadPacketReadingError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -660,7 +675,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(errClientStop)
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
@@ -669,7 +684,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.conn.Close()
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
@@ -677,7 +692,7 @@ func TestClientWritePacketWriteError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientWritePacketInvalidPacket(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.WritePacket(packets.Packet{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -51,15 +52,30 @@ func main() {
|
||||
// `server.Publish` method. Subscribe to `direct/publish` using your
|
||||
// MQTT client to see the messages.
|
||||
go func() {
|
||||
cl := mqtt.NewInlineClient("inline", "local")
|
||||
for range time.Tick(time.Second * 10) {
|
||||
server.InjectPacket(cl, packets.Packet{
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
for range time.Tick(time.Second * 1) {
|
||||
err := server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "direct/publish",
|
||||
Payload: []byte("scheduled message"),
|
||||
Payload: []byte("injected scheduled message"),
|
||||
})
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.InjectPacket")
|
||||
}
|
||||
server.Log.Info().Msgf("main.go injected packet to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
// There is also a shorthand convenience function, Publish, for easily sending
|
||||
// publish packets if you are not concerned with creating your own packets.
|
||||
go func() {
|
||||
for range time.Tick(time.Second * 5) {
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.Publish")
|
||||
}
|
||||
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -28,7 +29,6 @@ func main() {
|
||||
server.Options.Capabilities.ServerKeepAlive = 60
|
||||
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
||||
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
|
||||
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
|
||||
|
||||
_ = server.AddHook(new(pahoAuthHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
|
25
hooks.go
25
hooks.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -46,7 +47,6 @@ const (
|
||||
OnWillSent
|
||||
OnClientExpired
|
||||
OnRetainedExpired
|
||||
OnExpireInflights
|
||||
StoredClients
|
||||
StoredSubscriptions
|
||||
StoredInflightMessages
|
||||
@@ -95,7 +95,6 @@ type Hook interface {
|
||||
OnWillSent(cl *Client, pk packets.Packet)
|
||||
OnClientExpired(cl *Client)
|
||||
OnRetainedExpired(filter string)
|
||||
OnExpireInflights(cl *Client, expiry int64)
|
||||
StoredClients() ([]storage.Client, error)
|
||||
StoredSubscriptions() ([]storage.Subscription, error)
|
||||
StoredInflightMessages() ([]storage.Message, error)
|
||||
@@ -350,7 +349,7 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnMessage
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnPublished
|
||||
// in that it allows you to modify you to modify the incoming packet before it is processed.
|
||||
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
||||
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
@@ -413,8 +412,8 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires. In other words, when
|
||||
// an inflight message expires or is abandoned.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
// an inflight message expires or is abandoned. It is typically used to delete an
|
||||
// inflight message from a store.
|
||||
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
if hook.Provides(OnQosDropped) {
|
||||
@@ -600,19 +599,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnExpireInflights is called when the server issues a clear request for expired
|
||||
// inflight messages. Expiry should be the time after which the message is no longer
|
||||
// valid (usually some time in the past). A message has expired if it's created time
|
||||
// is older than time.Now() minus Inflight TTL. This method can be used to expire
|
||||
// old inflight messages in a persistent store which doesnt support per-item TTL.
|
||||
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
|
||||
for _, hook := range h.internal {
|
||||
if hook.Provides(OnExpireInflights) {
|
||||
hook.OnExpireInflights(cl, expiry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HookBase provides a set of default methods for each hook. It should be embedded in
|
||||
// all hooks.
|
||||
type HookBase struct {
|
||||
@@ -754,9 +740,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
|
||||
// OnRetainedExpired is called when a retained message for a topic has expired.
|
||||
func (h *HookBase) OnRetainedExpired(topic string) {}
|
||||
|
||||
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
|
||||
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
|
||||
|
||||
// StoredClients returns all clients from a store.
|
||||
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
|
||||
return
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
@@ -79,7 +80,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -347,32 +347,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var v []storage.Message
|
||||
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
err := h.db.Delete(m.ID, new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
err := h.db.Delete(retainedKey(filter), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
|
||||
@@ -381,6 +362,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
|
||||
|
@@ -1,16 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/asdine/storm/v3"
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
@@ -169,6 +168,21 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -332,6 +346,21 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -418,48 +447,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var v []storage.Message
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, v, 1)
|
||||
require.Equal(t, "i1", v[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
@@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -369,34 +368,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the
|
||||
// provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var v []storage.Message
|
||||
err := h.db.Find("T", storage.InflightKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
|
||||
}
|
||||
@@ -404,6 +382,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
|
@@ -1,10 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -211,6 +211,21 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -340,6 +355,21 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -426,48 +456,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var v []storage.Message
|
||||
err = h.db.Find("T", storage.InflightKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, v, 1)
|
||||
require.Equal(t, "i1", v[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
@@ -82,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -363,37 +363,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the
|
||||
// provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
|
||||
}
|
||||
|
||||
if d.Created < expiry || d.Created == 0 {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
|
||||
@@ -402,6 +378,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
@@ -252,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, redis.Nil, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
@@ -391,6 +408,22 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
@@ -483,60 +516,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
n := time.Now().Unix()
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
|
||||
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
|
||||
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
|
||||
&storage.Message{ID: "i3", T: storage.InflightKey},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var r []storage.Message
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
err = d.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
r = append(r, d)
|
||||
}
|
||||
require.Len(t, r, 1)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -230,7 +231,6 @@ func TestHooksNonReturns(t *testing.T) {
|
||||
h.OnWillSent(cl, packets.Packet{})
|
||||
h.OnClientExpired(cl)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
h.OnExpireInflights(cl, time.Now().Unix()-1)
|
||||
|
||||
// on second iteration, check added hook methods
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func TestInflightSet(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.True(t, r)
|
||||
@@ -24,7 +25,7 @@ func TestInflightSet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightGet(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
msg, ok := cl.State.Inflight.Get(2)
|
||||
@@ -33,7 +34,7 @@ func TestInflightGet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
@@ -55,13 +56,13 @@ func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightLen(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestInflightDelete(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
|
||||
require.NotNil(t, cl.State.Inflight.internal[3])
|
||||
@@ -162,7 +163,7 @@ func TestSendQuota(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNextImmediate(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
98
listeners/unixsock.go
Normal file
98
listeners/unixsock.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
|
||||
type UnixSock struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
|
||||
func NewUnixSock(id, address string) *UnixSock {
|
||||
return &UnixSock{
|
||||
id: id,
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *UnixSock) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *UnixSock) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *UnixSock) Protocol() string {
|
||||
return "unix"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *UnixSock) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
_ = os.Remove(l.address)
|
||||
l.listen, err = net.Listen("unix", l.address)
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new UnixSock connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *UnixSock) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *UnixSock) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
96
listeners/unixsock_test.go
Normal file
96
listeners/unixsock_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testUnixAddr = "mochi.sock"
|
||||
|
||||
func TestNewUnixSock(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testUnixAddr, l.address)
|
||||
}
|
||||
|
||||
func TestUnixSockID(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestUnixSockAddress(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, testUnixAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestUnixSockProtocol(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "unix", l.Protocol())
|
||||
}
|
||||
|
||||
func TestUnixSockInit(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewUnixSock("t2", testUnixAddr)
|
||||
err = l2.Init(&logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnixSockServeAndClose(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("unix", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
@@ -1,11 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -136,25 +138,35 @@ type wsConn struct {
|
||||
}
|
||||
|
||||
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
|
||||
func (ws *wsConn) Read(p []byte) (n int, err error) {
|
||||
func (ws *wsConn) Read(p []byte) (int, error) {
|
||||
op, r, err := ws.c.NextReader()
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if op != websocket.BinaryMessage {
|
||||
err = ErrInvalidMessage
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return r.Read(p)
|
||||
var n, br int
|
||||
for {
|
||||
br, err = r.Read(p[n:])
|
||||
n += br
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes bytes to the websocket connection.
|
||||
func (ws *wsConn) Write(p []byte) (n int, err error) {
|
||||
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
func (ws *wsConn) Write(p []byte) (int, error) {
|
||||
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
// Code contains a reason code and reason string for a response.
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
// TPacketCase contains data for cross-checking the encoding and decoding
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
116
server.go
116
server.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
// package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
|
||||
package mqtt
|
||||
|
||||
@@ -25,10 +26,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "2.0.0" // the current server version.
|
||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||
defaultFanPoolSize uint64 = 64 // the number of concurrent workers in the pool
|
||||
defaultFanPoolQueueSize uint64 = 32 * 128 // the capacity of each worker queue
|
||||
Version = "2.1.0" // the current server version.
|
||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||
defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
|
||||
defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -198,6 +199,31 @@ func (o *Options) ensureDefaults() {
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient returns a new Client instance, populated with all the required values and
|
||||
// references to be used with the server. If you are using this client to directly publish
|
||||
// messages from the embedding application, set the inline flag to true to bypass ACL and
|
||||
// topic validation checks.
|
||||
func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client {
|
||||
cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit
|
||||
capabilities: s.Options.Capabilities,
|
||||
info: s.Info,
|
||||
hooks: s.hooks,
|
||||
log: s.Log,
|
||||
})
|
||||
|
||||
cl.ID = id
|
||||
cl.Net.Listener = listener
|
||||
|
||||
if inline { // inline clients bypass acl and some validity checks.
|
||||
cl.Net.Inline = true
|
||||
// By default we don't want to restrict developer publishes,
|
||||
// but if you do, reset this after creating inline client.
|
||||
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
|
||||
}
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// AddHook attaches a new Hook to the server. Ideally, this should be called
|
||||
// before the server is started with s.Serve().
|
||||
func (s *Server) AddHook(hook Hook, config any) error {
|
||||
@@ -280,27 +306,21 @@ func (s *Server) eventLoop() {
|
||||
}
|
||||
|
||||
// EstablishConnection establishes a new client when a listener accepts a new connection.
|
||||
func (s *Server) EstablishConnection(lid string, c net.Conn) error {
|
||||
cl := NewClient(c, &ops{ // [MQTT-3.1.2-6] implicit
|
||||
capabilities: s.Options.Capabilities,
|
||||
info: s.Info,
|
||||
hooks: s.hooks,
|
||||
log: s.Log,
|
||||
})
|
||||
|
||||
return s.attachClient(cl, lid)
|
||||
func (s *Server) EstablishConnection(listener string, c net.Conn) error {
|
||||
cl := s.NewClient(c, listener, "", false)
|
||||
return s.attachClient(cl, listener)
|
||||
}
|
||||
|
||||
// attachClient validates an incoming client connection and if viable, attaches the client
|
||||
// to the server, performs session housekeeping, and reads incoming packets.
|
||||
func (s *Server) attachClient(cl *Client, lid string) error {
|
||||
func (s *Server) attachClient(cl *Client, listener string) error {
|
||||
defer cl.Stop(nil)
|
||||
pk, err := s.readConnectionPacket(cl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read connection: %w", err)
|
||||
}
|
||||
|
||||
cl.ParseConnect(lid, pk)
|
||||
cl.ParseConnect(listener, pk)
|
||||
code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2]
|
||||
if code != packets.CodeSuccess {
|
||||
if err := s.sendConnack(cl, code, false); err != nil {
|
||||
@@ -352,11 +372,11 @@ func (s *Server) attachClient(cl *Client, lid string) error {
|
||||
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
|
||||
}
|
||||
|
||||
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", lid).Msg("client disconnected")
|
||||
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
|
||||
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
|
||||
s.hooks.OnDisconnect(cl, err, expire)
|
||||
if expire {
|
||||
s.unsubscribeClient(cl)
|
||||
s.UnsubscribeClient(cl)
|
||||
cl.ClearInflights(math.MaxInt64, 0)
|
||||
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
|
||||
}
|
||||
@@ -435,7 +455,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
|
||||
defer existing.Unlock()
|
||||
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
||||
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
|
||||
s.unsubscribeClient(existing)
|
||||
s.UnsubscribeClient(existing)
|
||||
existing.ClearInflights(math.MaxInt64, 0)
|
||||
return false // [MQTT-3.2.2-3]
|
||||
}
|
||||
@@ -591,6 +611,24 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Publish publishes a publish packet into the broker as if it were sent from the speicfied client.
|
||||
// This is a convenience function which wraps InjectPacket. As such, this method can publish packets
|
||||
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
|
||||
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
|
||||
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
|
||||
cl := s.NewClient(nil, "local", "inline", true)
|
||||
return s.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Qos: qos,
|
||||
Retain: retain,
|
||||
},
|
||||
TopicName: topic,
|
||||
Payload: payload,
|
||||
PacketID: uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks.
|
||||
})
|
||||
}
|
||||
|
||||
// InjectPacket injects a packet into the broker as if it were sent from the specified client.
|
||||
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
|
||||
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
|
||||
@@ -626,7 +664,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
pk.Origin = cl.ID
|
||||
pk.Created = time.Now().Unix()
|
||||
|
||||
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
|
||||
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok && !cl.Net.Inline {
|
||||
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
|
||||
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
|
||||
return cl.WritePacket(ack)
|
||||
@@ -659,6 +697,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
s.publishToSubscribers(pk)
|
||||
})
|
||||
|
||||
s.hooks.OnPublished(cl, pk)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -689,8 +728,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
s.publishToSubscribers(pk)
|
||||
})
|
||||
|
||||
s.hooks.OnPublish(cl, pk)
|
||||
|
||||
s.hooks.OnPublished(cl, pk)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -924,7 +962,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
||||
code = packets.ErrPacketIdentifierInUse
|
||||
}
|
||||
|
||||
existed := false
|
||||
filterExisted := make([]bool, len(pk.Filters))
|
||||
reasonCodes := make([]byte, len(pk.Filters))
|
||||
for i, sub := range pk.Filters {
|
||||
if code != packets.CodeSuccess {
|
||||
@@ -940,8 +978,8 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
||||
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
|
||||
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
|
||||
} else {
|
||||
existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
|
||||
if !existed {
|
||||
isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
|
||||
if isNew {
|
||||
atomic.AddInt64(&s.Info.Subscriptions, 1)
|
||||
}
|
||||
cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]
|
||||
@@ -950,6 +988,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
||||
sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
|
||||
}
|
||||
|
||||
filterExisted[i] = !isNew
|
||||
reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
|
||||
}
|
||||
|
||||
@@ -984,7 +1023,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
||||
continue
|
||||
}
|
||||
|
||||
s.publishRetainedToClient(cl, sub, existed)
|
||||
s.publishRetainedToClient(cl, sub, filterExisted[i])
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1034,14 +1073,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
|
||||
return cl.WritePacket(ack)
|
||||
}
|
||||
|
||||
// unsubscribeClient unsubscribes a client from all of their subscriptions.
|
||||
func (s *Server) unsubscribeClient(cl *Client) {
|
||||
for k := range cl.State.Subscriptions.GetAll() {
|
||||
// UnsubscribeClient unsubscribes a client from all of their subscriptions.
|
||||
func (s *Server) UnsubscribeClient(cl *Client) {
|
||||
i := 0
|
||||
filterMap := cl.State.Subscriptions.GetAll()
|
||||
filters := make([]packets.Subscription, len(filterMap))
|
||||
for k, v := range filterMap {
|
||||
cl.State.Subscriptions.Delete(k)
|
||||
if s.Topics.Unsubscribe(k, cl.ID) {
|
||||
atomic.AddInt64(&s.Info.Subscriptions, -1)
|
||||
}
|
||||
filters[i] = v
|
||||
i++
|
||||
}
|
||||
s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
|
||||
}
|
||||
|
||||
// processAuth processes an Auth packet.
|
||||
@@ -1086,9 +1131,14 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {
|
||||
out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1]
|
||||
}
|
||||
|
||||
// We already have a code we are using to disconnect the client, so we are not
|
||||
// interested if the write packet fails due to a closed connection (as we are closing it).
|
||||
err := cl.WritePacket(out)
|
||||
if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
|
||||
cl.Stop(code)
|
||||
if code.Code >= packets.ErrUnspecifiedError.Code {
|
||||
return code
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -1303,9 +1353,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) {
|
||||
// loadClients restores clients from the datastore.
|
||||
func (s *Server) loadClients(v []storage.Client) {
|
||||
for _, c := range v {
|
||||
cl := newClientStub()
|
||||
cl.ID = c.ID
|
||||
cl.Net.Listener = c.Listener
|
||||
cl := s.NewClient(nil, c.Listener, c.ID, false)
|
||||
cl.Properties.Username = c.Username
|
||||
cl.Properties.Clean = c.Clean
|
||||
cl.Properties.ProtocolVersion = c.ProtocolVersion
|
||||
@@ -1412,8 +1460,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) {
|
||||
// clearExpiredInflights deletes any inflight messages which have expired.
|
||||
func (s *Server) clearExpiredInflights(now int64) {
|
||||
for _, client := range s.Clients.GetAll() {
|
||||
if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 {
|
||||
s.hooks.OnExpireInflights(client, now)
|
||||
if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
|
||||
for _, id := range deleted {
|
||||
s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
296
server_test.go
296
server_test.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -101,7 +102,34 @@ func TestNewNilOpts(t *testing.T) {
|
||||
require.NotNil(t, s.Options)
|
||||
}
|
||||
|
||||
func TestAddHook(t *testing.T) {
|
||||
func TestServerNewClient(t *testing.T) {
|
||||
s := New(nil)
|
||||
s.Log = &logger
|
||||
r, _ := net.Pipe()
|
||||
|
||||
cl := s.NewClient(r, "testing", "test", false)
|
||||
require.NotNil(t, cl)
|
||||
require.Equal(t, "test", cl.ID)
|
||||
require.Equal(t, "testing", cl.Net.Listener)
|
||||
require.False(t, cl.Net.Inline)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.NotNil(t, cl.ops)
|
||||
require.Equal(t, s.Log, cl.ops.log)
|
||||
}
|
||||
|
||||
func TestServerNewClientInline(t *testing.T) {
|
||||
s := New(nil)
|
||||
cl := s.NewClient(nil, "testing", "test", true)
|
||||
require.True(t, cl.Net.Inline)
|
||||
}
|
||||
|
||||
func TestServerAddHook(t *testing.T) {
|
||||
s := New(nil)
|
||||
s.Log = &logger
|
||||
require.NotNil(t, s)
|
||||
@@ -112,7 +140,7 @@ func TestAddHook(t *testing.T) {
|
||||
require.Equal(t, int64(1), s.hooks.Len())
|
||||
}
|
||||
|
||||
func TestAddListener(t *testing.T) {
|
||||
func TestServerAddListener(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
|
||||
@@ -127,7 +155,7 @@ func TestAddListener(t *testing.T) {
|
||||
require.Equal(t, ErrListenerIDExists, err)
|
||||
}
|
||||
|
||||
func TestAddListenerInitFailure(t *testing.T) {
|
||||
func TestServerAddListenerInitFailure(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
|
||||
@@ -196,7 +224,7 @@ func TestServerReadConnectionPacket(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
o := make(chan packets.Packet)
|
||||
@@ -218,7 +246,7 @@ func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
o := make(chan error)
|
||||
@@ -241,7 +269,7 @@ func TestServerReadConnectionPacketBadPacketType(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
go func() {
|
||||
@@ -258,7 +286,7 @@ func TestServerReadConnectionPacketBadPacket(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
go func() {
|
||||
@@ -376,7 +404,7 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
|
||||
cl, r0, _ := newClient()
|
||||
cl, r0, _ := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
|
||||
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
|
||||
@@ -437,7 +465,7 @@ func TestEstablishConnectionResentPendingInflightsError(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl, r0, _ := newClient()
|
||||
cl, r0, _ := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
|
||||
cl.State.Inflight = NewInflights()
|
||||
@@ -473,7 +501,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
|
||||
cl, r0, _ := newClient()
|
||||
cl, r0, _ := newTestClient()
|
||||
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
|
||||
cl.Properties.Clean = true
|
||||
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
|
||||
@@ -659,7 +687,7 @@ func TestServerEstablishConnectionBadPacket(t *testing.T) {
|
||||
|
||||
func TestServerSendConnack(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
s.Options.Capabilities.ServerKeepAlive = 20
|
||||
s.Options.Capabilities.MaximumQos = 1
|
||||
@@ -679,7 +707,7 @@ func TestServerSendConnack(t *testing.T) {
|
||||
|
||||
func TestServerSendConnackFailureReason(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
s.Options.Capabilities.ServerKeepAlive = 20
|
||||
go func() {
|
||||
@@ -757,7 +785,7 @@ func TestServerValidateConnect(t *testing.T) {
|
||||
|
||||
func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.Properties.Props.SessionExpiryInterval = uint32(300)
|
||||
s.Options.Capabilities.MaximumSessionExpiryInterval = 120
|
||||
@@ -777,7 +805,7 @@ func TestInheritClientSession(t *testing.T) {
|
||||
|
||||
n := time.Now().Unix()
|
||||
|
||||
existing, _, _ := newClient()
|
||||
existing, _, _ := newTestClient()
|
||||
existing.Net.conn = nil
|
||||
existing.ID = "mochi"
|
||||
existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
|
||||
@@ -787,7 +815,7 @@ func TestInheritClientSession(t *testing.T) {
|
||||
|
||||
s.Clients.Add(existing)
|
||||
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
|
||||
require.Equal(t, 0, cl.State.Inflight.Len())
|
||||
@@ -800,7 +828,7 @@ func TestInheritClientSession(t *testing.T) {
|
||||
require.Equal(t, 1, cl.State.Subscriptions.Len())
|
||||
|
||||
// On clean, clear existing properties
|
||||
cl, _, _ = newClient()
|
||||
cl, _, _ = newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
b = s.inheritClientSession(packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi", Clean: true}}, cl)
|
||||
require.False(t, b)
|
||||
@@ -810,27 +838,27 @@ func TestInheritClientSession(t *testing.T) {
|
||||
|
||||
func TestServerUnsubscribeClient(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
pk := packets.Subscription{Filter: "a/b/c", Qos: 1}
|
||||
cl.State.Subscriptions.Add("a/b/c", pk)
|
||||
s.Topics.Subscribe(cl.ID, pk)
|
||||
subs := s.Topics.Subscribers("a/b/c")
|
||||
require.Equal(t, 1, len(subs.Subscriptions))
|
||||
s.unsubscribeClient(cl)
|
||||
s.UnsubscribeClient(cl)
|
||||
subs = s.Topics.Subscribers("a/b/c")
|
||||
require.Equal(t, 0, len(subs.Subscriptions))
|
||||
}
|
||||
|
||||
func TestServerProcessPacketFailure(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := s.processPacket(cl, packets.Packet{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestServerProcessPacketConnect(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet)
|
||||
require.Error(t, err)
|
||||
@@ -838,7 +866,7 @@ func TestServerProcessPacketConnect(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPingreq(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet)
|
||||
@@ -853,7 +881,7 @@ func TestServerProcessPacketPingreq(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPingreqError(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(packets.CodeDisconnect)
|
||||
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet)
|
||||
@@ -863,7 +891,7 @@ func TestServerProcessPacketPingreqError(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPublishInvalid(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishInvalidQosMustPacketID).Packet)
|
||||
require.Error(t, err)
|
||||
@@ -875,12 +903,12 @@ func TestInjectPacketPublishAndReceive(t *testing.T) {
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
|
||||
sender, _, w1 := newClient()
|
||||
sender, _, w1 := newTestClient()
|
||||
sender.Net.Inline = true
|
||||
sender.ID = "sender"
|
||||
s.Clients.Add(sender)
|
||||
|
||||
receiver, r2, w2 := newClient()
|
||||
receiver, r2, w2 := newTestClient()
|
||||
receiver.ID = "receiver"
|
||||
s.Clients.Add(receiver)
|
||||
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
|
||||
@@ -905,10 +933,46 @@ func TestInjectPacketPublishAndReceive(t *testing.T) {
|
||||
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
|
||||
}
|
||||
|
||||
func TestServerDirectPublishAndReceive(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
|
||||
sender, _, w1 := newTestClient()
|
||||
sender.Net.Inline = true
|
||||
sender.ID = "sender"
|
||||
s.Clients.Add(sender)
|
||||
|
||||
receiver, r2, w2 := newTestClient()
|
||||
receiver.ID = "receiver"
|
||||
s.Clients.Add(receiver)
|
||||
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
|
||||
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
|
||||
|
||||
receiverBuf := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := io.ReadAll(r2)
|
||||
require.NoError(t, err)
|
||||
receiverBuf <- buf
|
||||
}()
|
||||
|
||||
go func() {
|
||||
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
|
||||
err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos)
|
||||
require.NoError(t, err)
|
||||
w1.Close()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
w2.Close()
|
||||
}()
|
||||
|
||||
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
|
||||
}
|
||||
|
||||
func TestInjectPacketError(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Inline = true
|
||||
pkx := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet
|
||||
pkx.Filters = packets.Subscriptions{}
|
||||
@@ -919,7 +983,7 @@ func TestInjectPacketError(t *testing.T) {
|
||||
func TestInjectPacketPublishInvalidTopic(t *testing.T) {
|
||||
s := newServer()
|
||||
defer s.Close()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Inline = true
|
||||
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
|
||||
pkx.TopicName = "$SYS/test"
|
||||
@@ -932,11 +996,11 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) {
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
|
||||
sender, _, w1 := newClient()
|
||||
sender, _, w1 := newTestClient()
|
||||
sender.ID = "sender"
|
||||
s.Clients.Add(sender)
|
||||
|
||||
receiver, r2, w2 := newClient()
|
||||
receiver, r2, w2 := newTestClient()
|
||||
receiver.ID = "receiver"
|
||||
s.Clients.Add(receiver)
|
||||
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
|
||||
@@ -965,7 +1029,7 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketAndNextImmediate(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
next := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
|
||||
next.Expiry = -1
|
||||
@@ -992,7 +1056,7 @@ func TestServerProcessPacketPublishAckFailure(t *testing.T) {
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
|
||||
cl, _, w := newClient()
|
||||
cl, _, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
w.Close()
|
||||
@@ -1006,14 +1070,15 @@ func TestServerProcessPacketPublishMaximumReceive(t *testing.T) {
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.State.Inflight.ResetReceiveQuota(0)
|
||||
s.Clients.Add(cl)
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrReceiveMaximum)
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
@@ -1026,7 +1091,7 @@ func TestServerProcessPublishInvalidTopic(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet)
|
||||
require.NoError(t, err) // $SYS topics should be ignored?
|
||||
}
|
||||
@@ -1039,7 +1104,7 @@ func TestServerProcessPublishACLCheckDeny(t *testing.T) {
|
||||
})
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
|
||||
require.NoError(t, err) // ACL check fails silently
|
||||
}
|
||||
@@ -1056,14 +1121,14 @@ func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) {
|
||||
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err = s.processPublish(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
|
||||
require.NoError(t, err) // packets rejected silently
|
||||
}
|
||||
|
||||
func TestServerProcessPacketPublishQos0(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
|
||||
@@ -1078,7 +1143,7 @@ func TestServerProcessPacketPublishQos0(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
|
||||
atomic.StoreInt64(&s.Info.Inflight, 1)
|
||||
|
||||
@@ -1096,7 +1161,7 @@ func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Pubrec}})
|
||||
atomic.StoreInt64(&s.Info.Inflight, 1)
|
||||
@@ -1115,7 +1180,7 @@ func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPublishQos1(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
|
||||
@@ -1130,7 +1195,7 @@ func TestServerProcessPacketPublishQos1(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPublishQos2(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet)
|
||||
@@ -1146,7 +1211,7 @@ func TestServerProcessPacketPublishQos2(t *testing.T) {
|
||||
func TestServerProcessPacketPublishDowngradeQos(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Options.Capabilities.MaximumQos = 1
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet)
|
||||
@@ -1161,7 +1226,7 @@ func TestServerProcessPacketPublishDowngradeQos(t *testing.T) {
|
||||
|
||||
func TestPublishToSubscribersSelfNoLocal(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", NoLocal: true})
|
||||
require.True(t, subbed)
|
||||
@@ -1186,11 +1251,11 @@ func TestPublishToSubscribersSelfNoLocal(t *testing.T) {
|
||||
|
||||
func TestPublishToSubscribers(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r1, w1 := newClient()
|
||||
cl, r1, w1 := newTestClient()
|
||||
cl.ID = "cl1"
|
||||
cl2, r2, w2 := newClient()
|
||||
cl2, r2, w2 := newTestClient()
|
||||
cl2.ID = "cl2"
|
||||
cl3, r3, w3 := newClient()
|
||||
cl3, r3, w3 := newTestClient()
|
||||
cl3.ID = "cl3"
|
||||
s.Clients.Add(cl)
|
||||
s.Clients.Add(cl2)
|
||||
@@ -1248,7 +1313,7 @@ func TestPublishToSubscribers(t *testing.T) {
|
||||
func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Options.Capabilities.MaximumMessageExpiryInterval = 86400
|
||||
cl, r1, w1 := newClient()
|
||||
cl, r1, w1 := newTestClient()
|
||||
cl.ID = "cl1"
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
s.Clients.Add(cl)
|
||||
@@ -1277,7 +1342,7 @@ func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) {
|
||||
|
||||
func TestPublishToSubscribersIdentifiers(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
s.Clients.Add(cl)
|
||||
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/+", Identifier: 2})
|
||||
@@ -1307,7 +1372,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Options.Capabilities.MaximumQos = 1
|
||||
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
_, ok := cl.State.Inflight.Get(1)
|
||||
@@ -1333,7 +1398,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
|
||||
|
||||
func TestPublishToClientServerTopicAlias(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.Properties.Props.TopicAliasMaximum = 5
|
||||
s.Clients.Add(cl)
|
||||
@@ -1362,7 +1427,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) {
|
||||
|
||||
func TestPublishToClientExhaustedPacketID(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
for i := 0; i <= 65535; i++ {
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
|
||||
}
|
||||
@@ -1374,7 +1439,7 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) {
|
||||
|
||||
func TestPublishToClientNoConn(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.conn = nil
|
||||
|
||||
_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
|
||||
@@ -1384,12 +1449,12 @@ func TestPublishToClientNoConn(t *testing.T) {
|
||||
|
||||
func TestProcessPublishWithTopicAlias(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})
|
||||
require.True(t, subbed)
|
||||
|
||||
cl2, _, w2 := newClient()
|
||||
cl2, _, w2 := newTestClient()
|
||||
cl2.Properties.ProtocolVersion = 5
|
||||
cl2.State.TopicAliases.Inbound.Set(1, "a/b/c")
|
||||
|
||||
@@ -1411,7 +1476,7 @@ func TestProcessPublishWithTopicAlias(t *testing.T) {
|
||||
|
||||
func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
cl.State.Inflight.sendQuota = 0
|
||||
|
||||
@@ -1430,7 +1495,7 @@ func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) {
|
||||
|
||||
func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
for i := 0; i <= 65535; i++ {
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
@@ -1451,7 +1516,7 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
|
||||
|
||||
func TestPublishToSubscribersNoConnection(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
|
||||
require.True(t, subbed)
|
||||
@@ -1466,7 +1531,7 @@ func TestPublishToSubscribersNoConnection(t *testing.T) {
|
||||
|
||||
func TestPublishRetainedToClient(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
|
||||
@@ -1488,7 +1553,7 @@ func TestPublishRetainedToClient(t *testing.T) {
|
||||
|
||||
func TestPublishRetainedToClientIsShared(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
sub := packets.Subscription{Filter: SharePrefix + "/test/a/b/c"}
|
||||
@@ -1507,7 +1572,7 @@ func TestPublishRetainedToClientIsShared(t *testing.T) {
|
||||
|
||||
func TestPublishRetainedToClientError(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, w := newClient()
|
||||
cl, _, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
sub := packets.Subscription{Filter: "a/b/c"}
|
||||
@@ -1537,7 +1602,7 @@ func TestServerProcessPacketPuback(t *testing.T) {
|
||||
t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) {
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
|
||||
@@ -1559,7 +1624,7 @@ func TestServerProcessPacketPuback(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPubackNoPacketID(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
|
||||
@@ -1574,7 +1639,7 @@ func TestServerProcessPacketPubackNoPacketID(t *testing.T) {
|
||||
func TestServerProcessPacketPubrec(t *testing.T) {
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
|
||||
@@ -1603,7 +1668,7 @@ func TestServerProcessPacketPubrec(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPubrecNoPacketID(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
@@ -1629,7 +1694,7 @@ func TestServerProcessPacketPubrecNoPacketID(t *testing.T) {
|
||||
func TestServerProcessPacketPubrecInvalidReason(t *testing.T) {
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: pID})
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrecInvalidReason).Packet)
|
||||
require.NoError(t, err)
|
||||
@@ -1641,7 +1706,7 @@ func TestServerProcessPacketPubrecInvalidReason(t *testing.T) {
|
||||
func TestServerProcessPacketPubrecFailure(t *testing.T) {
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: pID})
|
||||
cl.Stop(packets.CodeDisconnect)
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet)
|
||||
@@ -1652,7 +1717,7 @@ func TestServerProcessPacketPubrecFailure(t *testing.T) {
|
||||
func TestServerProcessPacketPubrel(t *testing.T) {
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
|
||||
@@ -1682,7 +1747,7 @@ func TestServerProcessPacketPubrel(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketPubrelNoPacketID(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
@@ -1708,7 +1773,7 @@ func TestServerProcessPacketPubrelNoPacketID(t *testing.T) {
|
||||
func TestServerProcessPacketPubrelFailure(t *testing.T) {
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: pID})
|
||||
cl.Stop(packets.CodeDisconnect)
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet)
|
||||
@@ -1719,7 +1784,7 @@ func TestServerProcessPacketPubrelFailure(t *testing.T) {
|
||||
func TestServerProcessPacketPubrelBadReason(t *testing.T) {
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: pID})
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrelInvalidReason).Packet)
|
||||
require.NoError(t, err)
|
||||
@@ -1744,7 +1809,7 @@ func TestServerProcessPacketPubcomp(t *testing.T) {
|
||||
t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) {
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.ProtocolVersion = tx.protocolVersion
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
@@ -1791,7 +1856,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) {
|
||||
|
||||
pID := uint16(7)
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
|
||||
@@ -1862,7 +1927,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
|
||||
|
||||
pID := uint16(6)
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.packetID = uint32(6)
|
||||
cl.State.Inflight.sendQuota = 3
|
||||
cl.State.Inflight.receiveQuota = 3
|
||||
@@ -1906,7 +1971,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketSubscribe(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet)
|
||||
@@ -1921,7 +1986,7 @@ func TestServerProcessPacketSubscribe(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
|
||||
|
||||
@@ -1940,7 +2005,7 @@ func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketSubscribeInvalid(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeSpecQosMustPacketID).Packet)
|
||||
@@ -1950,7 +2015,7 @@ func TestServerProcessPacketSubscribeInvalid(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
|
||||
go func() {
|
||||
@@ -1966,7 +2031,7 @@ func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
|
||||
go func() {
|
||||
@@ -1982,7 +2047,7 @@ func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) {
|
||||
|
||||
func TestServerProcessSubscribeWithRetain(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
|
||||
require.Equal(t, int64(1), retained)
|
||||
@@ -2006,7 +2071,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) {
|
||||
func TestServerProcessSubscribeDowngradeQos(t *testing.T) {
|
||||
s := newServer()
|
||||
s.Options.Capabilities.MaximumQos = 1
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMany).Packet)
|
||||
@@ -2023,7 +2088,7 @@ func TestServerProcessSubscribeDowngradeQos(t *testing.T) {
|
||||
|
||||
func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"})
|
||||
s.Clients.Add(cl)
|
||||
|
||||
@@ -2045,7 +2110,7 @@ func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) {
|
||||
|
||||
func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
|
||||
@@ -2066,7 +2131,7 @@ func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) {
|
||||
|
||||
func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
|
||||
@@ -2090,7 +2155,7 @@ func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) {
|
||||
|
||||
func TestServerProcessSubscribeNoConnection(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
r.Close()
|
||||
err := s.processSubscribe(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet)
|
||||
require.Error(t, err)
|
||||
@@ -2104,7 +2169,7 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
|
||||
FanPoolQueueSize: 10,
|
||||
})
|
||||
s.Serve()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
|
||||
go func() {
|
||||
@@ -2126,7 +2191,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
|
||||
})
|
||||
s.Serve()
|
||||
s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
|
||||
go func() {
|
||||
@@ -2142,7 +2207,7 @@ func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
|
||||
|
||||
func TestServerProcessSubscribeErrorDowngrade(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 3
|
||||
cl.State.packetID = 1 // just to match the same packet id (7) in the fixtures
|
||||
|
||||
@@ -2159,7 +2224,7 @@ func TestServerProcessSubscribeErrorDowngrade(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketUnsubscribe(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b", Qos: 0})
|
||||
go func() {
|
||||
@@ -2176,7 +2241,7 @@ func TestServerProcessPacketUnsubscribe(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
|
||||
go func() {
|
||||
@@ -2193,7 +2258,7 @@ func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
|
||||
@@ -2201,7 +2266,7 @@ func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) {
|
||||
|
||||
func TestServerReceivePacketError(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := s.receivePacket(cl, *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
|
||||
@@ -2209,7 +2274,7 @@ func TestServerReceivePacketError(t *testing.T) {
|
||||
|
||||
func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
cl.Properties.Props.SessionExpiryInterval = 0
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.Properties.Props.RequestProblemInfo = 0
|
||||
@@ -2226,9 +2291,24 @@ func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
|
||||
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, buf)
|
||||
}
|
||||
|
||||
func TestServerRecievePacketDisconnectClient(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
go func() {
|
||||
err := s.DisconnectClient(cl, packets.CodeDisconnect)
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, buf)
|
||||
}
|
||||
|
||||
func TestServerProcessPacketDisconnect(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.SessionExpiryInterval = 30
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
|
||||
@@ -2245,7 +2325,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.SessionExpiryInterval = 0
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
cl.Properties.Props.RequestProblemInfo = 0
|
||||
@@ -2258,7 +2338,7 @@ func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketAuth(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
go func() {
|
||||
err := s.processPacket(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet)
|
||||
@@ -2273,7 +2353,7 @@ func TestServerProcessPacketAuth(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketAuthInvalidReason(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
pkx := *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet
|
||||
pkx.ReasonCode = 99
|
||||
err := s.processPacket(cl, pkx)
|
||||
@@ -2283,7 +2363,7 @@ func TestServerProcessPacketAuthInvalidReason(t *testing.T) {
|
||||
|
||||
func TestServerProcessPacketAuthFailure(t *testing.T) {
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
hook.fail = true
|
||||
@@ -2300,7 +2380,7 @@ func TestServerSendLWT(t *testing.T) {
|
||||
s.Serve()
|
||||
defer s.Close()
|
||||
|
||||
sender, _, w1 := newClient()
|
||||
sender, _, w1 := newTestClient()
|
||||
sender.ID = "sender"
|
||||
sender.Properties.Will = Will{
|
||||
Flag: 1,
|
||||
@@ -2309,7 +2389,7 @@ func TestServerSendLWT(t *testing.T) {
|
||||
}
|
||||
s.Clients.Add(sender)
|
||||
|
||||
receiver, r2, w2 := newClient()
|
||||
receiver, r2, w2 := newTestClient()
|
||||
receiver.ID = "receiver"
|
||||
s.Clients.Add(receiver)
|
||||
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})
|
||||
@@ -2336,7 +2416,7 @@ func TestServerSendLWT(t *testing.T) {
|
||||
|
||||
func TestServerSendLWTDelayed(t *testing.T) {
|
||||
s := newServer()
|
||||
cl1, _, _ := newClient()
|
||||
cl1, _, _ := newTestClient()
|
||||
cl1.ID = "cl1"
|
||||
cl1.Properties.Will = Will{
|
||||
Flag: 1,
|
||||
@@ -2347,7 +2427,7 @@ func TestServerSendLWTDelayed(t *testing.T) {
|
||||
}
|
||||
s.Clients.Add(cl1)
|
||||
|
||||
cl2, r, w := newClient()
|
||||
cl2, r, w := newTestClient()
|
||||
cl2.ID = "cl2"
|
||||
s.Clients.Add(cl2)
|
||||
require.True(t, s.Topics.Subscribe(cl2.ID, packets.Subscription{Filter: "a/b/c"}))
|
||||
@@ -2425,7 +2505,7 @@ func TestServerLoadSubscriptions(t *testing.T) {
|
||||
}
|
||||
|
||||
s := newServer()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
s.Clients.Add(cl)
|
||||
require.Equal(t, 0, cl.State.Subscriptions.Len())
|
||||
s.loadSubscriptions(v)
|
||||
@@ -2485,7 +2565,7 @@ func TestServerClose(t *testing.T) {
|
||||
hook := new(modifiedHookBase)
|
||||
s.AddHook(hook, nil)
|
||||
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
cl.Net.Listener = "t1"
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
s.Clients.Add(cl)
|
||||
@@ -2523,7 +2603,7 @@ func TestServerClearExpiredInflights(t *testing.T) {
|
||||
s.Options.Capabilities.MaximumMessageExpiryInterval = 4
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ops.info = s.Info
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
|
||||
@@ -2563,12 +2643,12 @@ func TestServerClearExpiredClients(t *testing.T) {
|
||||
|
||||
n := time.Now().Unix()
|
||||
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ID = "cl"
|
||||
s.Clients.Add(cl)
|
||||
|
||||
// No Expiry
|
||||
cl0, _, _ := newClient()
|
||||
cl0, _, _ := newTestClient()
|
||||
cl0.ID = "c0"
|
||||
cl0.State.disconnected = n - 10
|
||||
cl0.State.done = 1
|
||||
@@ -2578,7 +2658,7 @@ func TestServerClearExpiredClients(t *testing.T) {
|
||||
s.Clients.Add(cl0)
|
||||
|
||||
// Normal Expiry
|
||||
cl1, _, _ := newClient()
|
||||
cl1, _, _ := newTestClient()
|
||||
cl1.ID = "c1"
|
||||
cl1.State.disconnected = n - 10
|
||||
cl1.State.done = 1
|
||||
@@ -2588,7 +2668,7 @@ func TestServerClearExpiredClients(t *testing.T) {
|
||||
s.Clients.Add(cl1)
|
||||
|
||||
// No Expiry, indefinite session
|
||||
cl2, _, _ := newClient()
|
||||
cl2, _, _ := newTestClient()
|
||||
cl2.ID = "c2"
|
||||
cl2.State.disconnected = n - 10
|
||||
cl2.State.done = 1
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package system
|
||||
|
||||
// Info contains atomic counters and values for various server statistics
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
|
Reference in New Issue
Block a user