Compare commits

..

46 Commits

Author SHA1 Message Date
mochi-co
f90e52328d Update server version 2023-01-16 20:08:55 +00:00
JB
50aae47618 Publish retained messages only after connack (#147) 2023-01-16 19:50:01 +00:00
JB
0d79f2d63b Use Atomic instead of RWMutex for Hooks concurrency (#148)
* Use Atomic instead of RWMutex for Hooks concurrency
* Lock Hooks on Add Hook
2023-01-16 19:49:36 +00:00
JB
300152413c Ignore retain as published v3 (#142)
* Optimise Capabilities struct alignment

* Only use RetainAsPublished for v5 clients
2023-01-13 23:38:49 +00:00
mochi-co
0de1d731db Update version number 2023-01-10 00:01:21 +00:00
JB
80746abc52 Use correct connack return codes for MQTTv3 (#140) 2023-01-10 00:00:43 +00:00
mochi-co
a73cf4ca0e Update server version 2023-01-09 23:08:49 +00:00
mochi-co
bc549ee7ed Fix example imports 2023-01-09 22:52:24 +00:00
mochi-co
c464b46713 export client.Net.Conn for external use 2023-01-09 22:49:40 +00:00
mochi-co
05ce56008c Small code improvements 2023-01-09 22:49:20 +00:00
JB
8254cb0cbc Make hooks safe for concurrency (#139)
Co-authored-by: thedevop <60499013+thedevop@users.noreply.github.com>
2023-01-09 22:41:44 +00:00
mochi-co
4ae58b79e3 Update server version 2023-01-07 20:13:48 +00:00
thedevop
b895d688e0 Change inline check order (#133) 2023-01-07 20:02:05 +00:00
mochi-co
a600cd4ead fix grammar on Closed method doc 2023-01-07 17:57:04 +00:00
mochi-co
cdb44990cf Update version number 2023-01-07 17:30:58 +00:00
mochi-co
2d9c128111 Refactor stored subscription value assignments 2023-01-07 17:30:30 +00:00
mochi-co
a0d5bdb39f Fix Typos 2023-01-07 17:30:01 +00:00
Wind
4ebcef3cb6 Save subscription properties for mqttv5 (#131)
* Update redis.go

Save the subscription properties for mqqtv5

* Update badger.go

Save the subscription properties for mqqtv5

* Update bolt.go

Save the subscription properties for mqqtv5

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-01-07 17:24:23 +00:00
thedevop
fb8d4720d7 Add Client Closed (#130)
* Add Client Closed
* Add Client Closed
* Update clients_test.go
2023-01-07 17:14:51 +00:00
JB
4080c89127 Update README.md 2022-12-21 21:00:53 +00:00
JB
1b67e6f3f6 Update README.md 2022-12-21 20:58:25 +00:00
mochi-co
1adb02e087 Update readme and server version 2022-12-21 20:47:58 +00:00
JB
4d4140aa99 Connect ReturnResponseInfo only applies to Connack values (#128) 2022-12-21 20:37:08 +00:00
JB
e31840a37d Optimize inflight expiry (#127)
* Small formatting/style changes for filter existed

* Use OnQoSDropped hook instead of onInflightExpired
2022-12-21 19:44:25 +00:00
JB
7d2e16f2d3 Merge pull request #123 from wind-c/master
Variable existed in the method processSubscribe is unstable
2022-12-21 11:41:14 +00:00
JB
92cd935a16 Merge branch 'master' into master 2022-12-21 11:38:28 +00:00
JB
25ce27ce2d Merge pull request #124 from zgwit/master
Add unix socket listener
2022-12-21 11:28:23 +00:00
jason
527d084a4b Add unix socket listener 2022-12-20 23:02:59 +08:00
Wind
bb9f937bb0 Variable existed in the method processSubscribe is unstable
The variable existed can be changed repeatedly within a for loop. An array variable must be used to record the subscription of each filter.
2022-12-18 13:46:06 +08:00
Wind
511fe88684 Merge branch 'mochi-co:master' into master 2022-12-17 12:33:09 +08:00
JB
75504ff201 Update server version 2022-12-16 18:27:29 +00:00
Wind
a556feb325 Add the OnUnsubscribed hook to the unsubscribeClient method (#122)
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-16 18:23:58 +00:00
“Wind”
d06f47f4b9 Add the OnUnsubscribed hook to the unsubscribeClient method
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-17 00:40:06 +08:00
JB
8d4cc091b4 Update version number 2022-12-16 00:31:59 +00:00
JB
d8f28cb843 Enforce server max packet (#121)
* Enforce Server Maximum Packet Size on client read
* Fix tests
2022-12-16 00:30:23 +00:00
JB
88861c219d Merge pull request #116 from tommyminds/bugfix/ws_malformed_package
Fix websocket malformed packet bug
2022-12-15 18:21:53 +00:00
JB
7ba6cf28d9 Merge branch 'master' into bugfix/ws_malformed_package 2022-12-15 18:21:33 +00:00
JB
c174cfdc6b Merge pull request #119 from mochi-co/fix-on-published
Fix mis-typed onpublished hook, update version, fanpool defaults
2022-12-15 18:21:19 +00:00
mochi-co
4f198a99dd Fix mis-typed onpublished hook, update version, fanpool defaults 2022-12-15 18:19:02 +00:00
Tommy Maintz
2a9c9fcc40 Fix websocket malformed packet bug 2022-12-14 21:41:33 +01:00
JB
835a85c8bf Update README.md 2022-12-12 11:44:36 +00:00
mochi-co
fe5d9ffa61 Simplify Client construction, add NewClient method to Server, add Publish convenience method 2022-12-12 11:37:19 +00:00
mochi-co
aac186dcc1 Add newline for godoc formatting 2022-12-11 22:25:21 +00:00
JB
42931f332f Update badges to use v2 references 2022-12-11 21:44:44 +00:00
mochi-co
8a04648c09 Cleanup godoc formatting 2022-12-11 21:38:01 +00:00
JB
854c033fb6 Update README.md 2022-12-11 12:21:25 +00:00
65 changed files with 1025 additions and 628 deletions

View File

@@ -2,9 +2,9 @@
<p align="center">
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master&v2)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
</p>
@@ -19,6 +19,11 @@ MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It
## What's new in Version 2.0.0?
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
Don't forget to use the new v2 import paths:
```go
import "github.com/mochi-co/mqtt/v2"
```
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
- User and MQTTv5 Packet Properties
- Topic Aliases
@@ -78,22 +83,26 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
Importing Mochi MQTT as a package requires just a few lines of code to get started.
``` go
import (
"log"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
// Create the new MQTT Server.
server := mqtt.New(nil)
// Allow all connections.
_ = server.AddHook(new(auth.AllowHook), nil)
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.Serve()
if err != nil {
@@ -107,10 +116,14 @@ Examples of running the broker with various configurations can be found in the [
#### Network Listeners
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
- `listeners.NewTCP(...)` - A TCP listener.
- `listeners.NewWebsocket(...)` A Websocket listener.
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
| Listener | Usage |
| --- | --- |
| listeners.NewTCP | A TCP listener |
| listeners.NewUnixSock | A Unix Socket listener |
| listeners.NewWebsocket | A Websocket listener |
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
A `*listeners.Config` may be passed to configure TLS.
@@ -291,7 +304,6 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
@@ -300,13 +312,22 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
### Packet Injection
It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement).
### Direct Publish
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
```go
cl := mqtt.NewInlineClient("inline", "local")
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
```
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
### Packet Injection
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
```go
cl := server.NewClient(nil, "local", "inline", true)
server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -105,7 +106,7 @@ type Client struct {
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
conn net.Conn // the net.Conn used to establish the connection
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
@@ -145,14 +146,10 @@ type ClientState struct {
keepalive uint16 // the number of seconds the connection can wait
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, o *ops) *Client {
// newClient returns a new instance of Client. This is almost exclusively used by Server
// for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
cl := &Client{
Net: ClientConnection{
conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Remote: c.RemoteAddr().String(),
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
@@ -165,46 +162,19 @@ func NewClient(c net.Conn, o *ops) *Client {
ops: o,
}
if c != nil {
cl.Net = ClientConnection{
Conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Remote: c.RemoteAddr().String(),
}
}
cl.refreshDeadline(cl.State.keepalive)
return cl
}
// NewInlineClient returns a client used when publishing from the embedding system.
func NewInlineClient(id, remote string) *Client {
return &Client{
ID: id,
Net: ClientConnection{
Remote: remote,
Inline: true,
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
}
}
// newClientStub returns an instance of Client with minimal initializations, such as
// restoring client data from a db. In particular, the client is marked as offline (done).
func newClientStub() *Client {
return &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
done: 1,
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
}
}
// ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid
@@ -253,13 +223,13 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.Net.conn != nil {
if cl.Net.Conn != nil {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
@@ -320,17 +290,18 @@ func (cl *Client) ResendInflightMessages(force bool) error {
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
var deleted int64
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted++
deleted = append(deleted, tk.PacketID)
}
}
}
return deleted
}
@@ -370,8 +341,8 @@ func (cl *Client) Stop(err error) {
}
cl.State.endOnce.Do(func() {
if cl.Net.conn != nil {
_ = cl.Net.conn.Close() // omit close error
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
if err != nil {
@@ -391,6 +362,11 @@ func (cl *Client) StopCause() error {
return cl.State.stopCause.Load().(error)
}
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return atomic.LoadUint32(&cl.State.done) == 1
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
@@ -413,6 +389,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
return err
}
if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil
}
@@ -484,7 +464,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
return ErrConnectionClosed
}
if cl.Net.conn == nil {
if cl.Net.Conn == nil {
return nil
}
@@ -502,8 +482,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
@@ -553,7 +533,7 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.conn)
n, err := nb.WriteTo(cl.Net.Conn)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -21,10 +22,10 @@ const pkInfo = "packet type %v, %s"
var errClientStop = errors.New("test stop")
func newClient() (cl *Client, r net.Conn, w net.Conn) {
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = NewClient(w, &ops{
cl = newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
@@ -118,34 +119,21 @@ func TestClientsGetByListener(t *testing.T) {
}
func TestNewClient(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Nil(t, cl.StopCause())
}
func TestNewClientStub(t *testing.T) {
cl := newClientStub()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
}
func TestNewInlineClient(t *testing.T) {
cl := NewInlineClient("inline", "local")
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done))
require.Equal(t, "inline", cl.ID)
require.Equal(t, "local", cl.Net.Remote)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.False(t, cl.Net.Inline)
}
func TestClientParseConnect(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
@@ -182,7 +170,7 @@ func TestClientParseConnect(t *testing.T) {
}
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
@@ -207,13 +195,13 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) {
}
func TestClientParseConnectNoID(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.ParseConnect("tcp1", packets.Packet{})
require.NotEmpty(t, cl.ID)
}
func TestClientNextPacketID(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
i, err := cl.NextPacketID()
require.NoError(t, err)
@@ -225,7 +213,7 @@ func TestClientNextPacketID(t *testing.T) {
}
func TestClientNextPacketIDInUse(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
// skip over 2
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
@@ -248,7 +236,7 @@ func TestClientNextPacketIDInUse(t *testing.T) {
}
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
}
@@ -260,7 +248,7 @@ func TestClientNextPacketIDExhausted(t *testing.T) {
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.packetID = uint32(65534)
@@ -274,7 +262,7 @@ func TestClientNextPacketIDOverflow(t *testing.T) {
}
func TestClientClearInflights(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
@@ -284,13 +272,15 @@ func TestClientClearInflights(t *testing.T) {
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
cl.ClearInflights(n, 4)
deleted := cl.ClearInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
require.Equal(t, 2, cl.State.Inflight.Len())
}
func TestClientResendInflightMessages(t *testing.T) {
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
cl, r, w := newClient()
cl, r, w := newTestClient()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
@@ -310,7 +300,7 @@ func TestClientResendInflightMessages(t *testing.T) {
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newClient()
cl, r, _ := newTestClient()
r.Close()
cl.State.Inflight.Set(*pk1.Packet)
@@ -322,19 +312,19 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
}
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
}
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
@@ -349,7 +339,7 @@ func TestClientReadFixedHeader(t *testing.T) {
}
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
@@ -362,8 +352,24 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
require.Error(t, err)
}
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
@@ -377,7 +383,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
}
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
@@ -391,7 +397,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
}
func TestClientReadOK(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -445,7 +451,7 @@ func TestClientReadOK(t *testing.T) {
}
func TestClientReadDone(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
@@ -460,15 +466,23 @@ func TestClientReadDone(t *testing.T) {
}
func TestClientStop(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.Equal(t, nil, cl.StopCause())
}
func TestClientClosed(t *testing.T) {
cl, _, _ := newTestClient()
require.False(t, cl.Closed())
cl.Stop(nil)
require.True(t, cl.Closed())
}
func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -485,7 +499,7 @@ func TestClientReadFixedHeaderError(t *testing.T) {
}
func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -505,7 +519,7 @@ func TestClientReadReadHandlerErr(t *testing.T) {
}
func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -537,7 +551,7 @@ func TestClientReadReadPacketOK(t *testing.T) {
}
func TestClientReadPacket(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
for _, tx := range pkTable {
@@ -570,9 +584,17 @@ func TestClientReadPacket(t *testing.T) {
}
}
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
}
func TestClientWritePacket(t *testing.T) {
for _, tt := range pkTable {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
@@ -588,7 +610,7 @@ func TestClientWritePacket(t *testing.T) {
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.conn.Close()
cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
@@ -612,7 +634,7 @@ func TestClientWritePacket(t *testing.T) {
}
func TestWriteClientOversizePacket(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Properties.Props.MaximumPacketSize = 2
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
err := cl.WritePacket(pk)
@@ -621,7 +643,7 @@ func TestWriteClientOversizePacket(t *testing.T) {
}
func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -641,7 +663,7 @@ func TestClientReadPacketReadingError(t *testing.T) {
}
func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -660,7 +682,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
}
func TestClientWritePacketWriteNoConn(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Stop(errClientStop)
err := cl.WritePacket(*pkTable[1].Packet)
@@ -669,15 +691,15 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
}
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newClient()
cl.Net.conn.Close()
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
}
func TestClientWritePacketInvalidPacket(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
err := cl.WritePacket(packets.Packet{})
require.Error(t, err)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -51,15 +52,30 @@ func main() {
// `server.Publish` method. Subscribe to `direct/publish` using your
// MQTT client to see the messages.
go func() {
cl := mqtt.NewInlineClient("inline", "local")
for range time.Tick(time.Second * 10) {
server.InjectPacket(cl, packets.Packet{
cl := server.NewClient(nil, "local", "inline", true)
for range time.Tick(time.Second * 1) {
err := server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "direct/publish",
Payload: []byte("scheduled message"),
Payload: []byte("injected scheduled message"),
})
if err != nil {
server.Log.Error().Err(err).Msg("server.InjectPacket")
}
server.Log.Info().Msgf("main.go injected packet to direct/publish")
}
}()
// There is also a shorthand convenience function, Publish, for easily sending
// publish packets if you are not concerned with creating your own packets.
go func() {
for range time.Tick(time.Second * 5) {
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
if err != nil {
server.Log.Error().Err(err).Msg("server.Publish")
}
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
}
}()

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -28,7 +29,6 @@ func main() {
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
_ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
package mqtt
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (

122
hooks.go
View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// SPDX-FileContributor: mochi-co, thedevop
package mqtt
import (
@@ -46,7 +47,6 @@ const (
OnWillSent
OnClientExpired
OnRetainedExpired
OnExpireInflights
StoredClients
StoredSubscriptions
StoredInflightMessages
@@ -95,7 +95,6 @@ type Hook interface {
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
OnExpireInflights(cl *Client, expiry int64)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
@@ -111,10 +110,10 @@ type HookOptions struct {
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal []Hook // a slice of hooks
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
@@ -124,7 +123,7 @@ func (h *Hooks) Len() int64 {
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
for _, hb := range b {
if hook.Provides(hb) {
return true
@@ -139,26 +138,39 @@ func (h *Hooks) Provides(b ...byte) bool {
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
if h.internal == nil {
h.internal = []Hook{}
}
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
h.internal = append(h.internal, hook)
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
if err := hook.Stop(); err != nil {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
@@ -173,7 +185,7 @@ func (h *Hooks) Stop() {
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
@@ -182,7 +194,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) {
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
@@ -191,7 +203,7 @@ func (h *Hooks) OnStarted() {
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
@@ -200,7 +212,7 @@ func (h *Hooks) OnStopped() {
// OnConnect is called when a new client connects.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnect) {
hook.OnConnect(cl, pk)
}
@@ -209,7 +221,7 @@ func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
@@ -218,7 +230,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
@@ -228,7 +240,7 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
@@ -249,7 +261,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
@@ -265,7 +277,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet,
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
@@ -276,7 +288,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
@@ -286,7 +298,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
@@ -298,7 +310,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
@@ -308,7 +320,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
@@ -320,7 +332,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
@@ -333,7 +345,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
@@ -343,19 +355,19 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
}
}
// OnPublish is called when a client publishes a message. This method differs from OnMessage
// OnPublish is called when a client publishes a message. This method differs from OnPublished
// in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
@@ -374,7 +386,7 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
@@ -383,7 +395,7 @@ func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
@@ -394,7 +406,7 @@ func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
@@ -405,7 +417,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
@@ -413,10 +425,10 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned.
// It is typically used to delete an inflight message from a store.
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
@@ -428,7 +440,7 @@ func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
@@ -444,7 +456,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
@@ -453,7 +465,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
@@ -462,7 +474,7 @@ func (h *Hooks) OnClientExpired(cl *Client) {
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
@@ -472,7 +484,7 @@ func (h *Hooks) OnRetainedExpired(filter string) {
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
@@ -492,7 +504,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
@@ -512,7 +524,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
@@ -532,7 +544,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
@@ -551,7 +563,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
@@ -573,7 +585,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
@@ -589,7 +601,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
@@ -600,19 +612,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnExpireInflights is called when the server issues a clear request for expired
// inflight messages. Expiry should be the time after which the message is no longer
// valid (usually some time in the past). A message has expired if it's created time
// is older than time.Now() minus Inflight TTL. This method can be used to expire
// old inflight messages in a persistent store which doesnt support per-item TTL.
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
for _, hook := range h.internal {
if hook.Provides(OnExpireInflights) {
hook.OnExpireInflights(cl, expiry)
}
}
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
@@ -754,9 +753,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package debug
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
@@ -79,7 +80,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -198,11 +198,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Upsert(in.ID, in)
@@ -347,32 +351,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.Delete(m.ID, new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.Delete(retainedKey(filter), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
@@ -381,6 +366,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")

View File

@@ -1,16 +1,15 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"errors"
"os"
"strings"
"testing"
"time"
"github.com/asdine/storm/v3"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
@@ -169,6 +168,21 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, badgerhold.ErrNotFound, err)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -332,6 +346,21 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -418,48 +447,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -201,12 +200,17 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
@@ -369,34 +373,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
return
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
}
@@ -404,6 +387,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")

View File

@@ -1,10 +1,10 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package bolt
import (
"errors"
"os"
"testing"
"time"
@@ -211,6 +211,21 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, storm.ErrNotFound, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -340,6 +355,21 @@ func TestOnRetainedExpired(t *testing.T) {
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -426,48 +456,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
@@ -82,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -215,11 +215,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
@@ -363,37 +367,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
if d.Created < expiry || d.Created == 0 {
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
@@ -402,6 +382,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
@@ -252,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, redis.Nil, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -391,6 +408,22 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -483,60 +516,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
n := time.Now().Unix()
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
&storage.Message{ID: "i3", T: storage.InflightKey},
).Err()
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var r []storage.Message
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
require.NoError(t, err)
require.Len(t, rows, 1)
for _, row := range rows {
var d storage.Message
err = d.UnmarshalBinary([]byte(row))
require.NoError(t, err)
r = append(r, d)
}
require.Len(t, r, 1)
require.Equal(t, "i1", r[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -26,6 +27,10 @@ type modifiedHookBase struct {
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
@@ -177,12 +182,20 @@ func TestHooksProvides(t *testing.T) {
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddAndLen(t *testing.T) {
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {
@@ -230,7 +243,6 @@ func TestHooksNonReturns(t *testing.T) {
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
h.OnExpireInflights(cl, time.Now().Unix()-1)
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -12,7 +13,7 @@ import (
)
func TestInflightSet(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.True(t, r)
@@ -24,7 +25,7 @@ func TestInflightSet(t *testing.T) {
}
func TestInflightGet(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
msg, ok := cl.State.Inflight.Get(2)
@@ -33,7 +34,7 @@ func TestInflightGet(t *testing.T) {
}
func TestInflightGetAllAndImmediate(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
@@ -55,13 +56,13 @@ func TestInflightGetAllAndImmediate(t *testing.T) {
}
func TestInflightLen(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
require.NotNil(t, cl.State.Inflight.internal[3])
@@ -162,7 +163,7 @@ func TestSendQuota(t *testing.T) {
}
func TestNextImmediate(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

98
listeners/unixsock.go Normal file
View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"net"
"os"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once.
}
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock {
return &UnixSock{
id: id,
address: address,
}
}
// ID returns the id of the listener.
func (l *UnixSock) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *UnixSock) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *UnixSock) Protocol() string {
return "unix"
}
// Init initializes the listener.
func (l *UnixSock) Init(log *zerolog.Logger) error {
l.log = log
var err error
_ = os.Remove(l.address)
l.listen, err = net.Listen("unix", l.address)
return err
}
// Serve starts waiting for new UnixSock connections, and calls the establish
// connection callback for any received.
func (l *UnixSock) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *UnixSock) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testUnixAddr = "mochi.sock"
func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address)
}
func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.ID())
}
func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, testUnixAddr, l.Address())
}
func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "unix", l.Protocol())
}
func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr)
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
}
func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("unix", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@@ -1,11 +1,13 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"errors"
"io"
"net"
"net/http"
"sync"
@@ -136,25 +138,35 @@ type wsConn struct {
}
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
func (ws *wsConn) Read(p []byte) (int, error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
return 0, err
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
return 0, err
}
return r.Read(p)
var n, br int
for {
br, err = r.Read(p[n:])
n += br
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
}
return n, err
}
}
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
func (ws *wsConn) Write(p []byte) (int, error) {
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
return 0, err
}
return len(p), nil

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
// Code contains a reason code and reason string for a response.
@@ -123,4 +124,23 @@ var (
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
// MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
Err3ClientIdentifierNotValid = Code{Code: 0x02}
Err3ServerUnavailable = Code{Code: 0x03}
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
Err3NotAuthorized = Code{Code: 0x05}
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
// This is required because MQTTv3 has different return byte specification.
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
V5CodesToV3 = map[Code]Code{
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
ErrServerUnavailable: Err3ServerUnavailable,
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
ErrBadUsernameOrPassword: Err3NotAuthorized,
}
)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
// TPacketCase contains data for cross-checking the encoding and decoding
@@ -88,6 +89,7 @@ const (
TConnackServerUnavailable
TConnackBadUsernamePassword
TConnackBadUsernamePasswordNoSession
TConnackMqtt5BadUsernamePasswordNoSession
TConnackNotAuthorised
TConnackMalSessionPresent
TConnackMalReturnCode
@@ -1315,10 +1317,28 @@ var TPacketData = map[byte]TPacketCases{
Desc: "bad username or password no session",
RawBytes: []byte{
Connack << 4, 2, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0, // No session present
Err3NotAuthorized.Code, // use v3 remapping
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
},
ReasonCode: Err3NotAuthorized.Code,
},
},
{
Case: TConnackMqtt5BadUsernamePasswordNoSession,
Desc: "mqtt5 bad username or password no session",
RawBytes: []byte{
Connack << 4, 3, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0,
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
@@ -1326,6 +1346,7 @@ var TPacketData = map[byte]TPacketCases{
ReasonCode: ErrBadUsernameOrPassword.Code,
},
},
{
Case: TConnackNotAuthorised,
Desc: "not authorised",
@@ -1803,13 +1824,10 @@ var TPacketData = map[byte]TPacketCases{
Case: TPublishRetainMqtt5,
Desc: "retain mqtt5",
RawBytes: []byte{
Publish<<4 | 1<<0, 35, // Fixed header
Publish<<4 | 1<<0, 19, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
16, // properties length
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
0, // properties length
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
},
Packet: &Packet{
@@ -1817,18 +1835,11 @@ var TPacketData = map[byte]TPacketCases{
FixedHeader: FixedHeader{
Type: Publish,
Retain: true,
Remaining: 35,
Remaining: 19,
},
TopicName: "a/b/c",
Properties: Properties{
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
Payload: []byte("hello mochi"),
TopicName: "a/b/c",
Properties: Properties{},
Payload: []byte("hello mochi"),
},
},
{

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

160
server.go
View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
package mqtt
@@ -25,10 +26,10 @@ import (
)
const (
Version = "2.0.0" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 64 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 32 * 128 // the capacity of each worker queue
Version = "2.1.6" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
)
var (
@@ -60,13 +61,13 @@ type Capabilities struct {
ReceiveMaximum uint16
TopicAliasMaximum uint16
ServerKeepAlive uint16
SharedSubAvailable byte
MinimumProtocolVersion byte
Compatibilities Compatibilities
MaximumQos byte
RetainAvailable byte
WildcardSubAvailable byte
SubIDAvailable byte
SharedSubAvailable byte
MinimumProtocolVersion byte
}
// Compatibilities provides flags for using compatibility modes.
@@ -198,6 +199,31 @@ func (o *Options) ensureDefaults() {
}
}
// NewClient returns a new Client instance, populated with all the required values and
// references to be used with the server. If you are using this client to directly publish
// messages from the embedding application, set the inline flag to true to bypass ACL and
// topic validation checks.
func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client {
cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit
capabilities: s.Options.Capabilities,
info: s.Info,
hooks: s.hooks,
log: s.Log,
})
cl.ID = id
cl.Net.Listener = listener
if inline { // inline clients bypass acl and some validity checks.
cl.Net.Inline = true
// By default we don't want to restrict developer publishes,
// but if you do, reset this after creating inline client.
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
}
return cl
}
// AddHook attaches a new Hook to the server. Ideally, this should be called
// before the server is started with s.Serve().
func (s *Server) AddHook(hook Hook, config any) error {
@@ -280,27 +306,21 @@ func (s *Server) eventLoop() {
}
// EstablishConnection establishes a new client when a listener accepts a new connection.
func (s *Server) EstablishConnection(lid string, c net.Conn) error {
cl := NewClient(c, &ops{ // [MQTT-3.1.2-6] implicit
capabilities: s.Options.Capabilities,
info: s.Info,
hooks: s.hooks,
log: s.Log,
})
return s.attachClient(cl, lid)
func (s *Server) EstablishConnection(listener string, c net.Conn) error {
cl := s.NewClient(c, listener, "", false)
return s.attachClient(cl, listener)
}
// attachClient validates an incoming client connection and if viable, attaches the client
// to the server, performs session housekeeping, and reads incoming packets.
func (s *Server) attachClient(cl *Client, lid string) error {
func (s *Server) attachClient(cl *Client, listener string) error {
defer cl.Stop(nil)
pk, err := s.readConnectionPacket(cl)
if err != nil {
return fmt.Errorf("read connection: %w", err)
}
cl.ParseConnect(lid, pk)
cl.ParseConnect(listener, pk)
code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2]
if code != packets.CodeSuccess {
if err := s.sendConnack(cl, code, false); err != nil {
@@ -331,6 +351,13 @@ func (s *Server) attachClient(cl *Client, lid string) error {
return fmt.Errorf("ack connection packet: %w", err)
}
// Publish any retained messages for subscriptions which already existed on client takeover.
if sessionPresent {
for _, sub := range cl.State.Subscriptions.GetAll() {
s.publishRetainedToClient(cl, sub, true)
}
}
s.loop.willDelayed.Delete(cl.ID) // [MQTT-3.1.3-9]
if sessionPresent {
@@ -352,11 +379,11 @@ func (s *Server) attachClient(cl *Client, lid string) error {
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
}
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", lid).Msg("client disconnected")
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)
if expire {
s.unsubscribeClient(cl)
s.UnsubscribeClient(cl)
cl.ClearInflights(math.MaxInt64, 0)
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
}
@@ -435,7 +462,7 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
defer existing.Unlock()
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.unsubscribeClient(existing)
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
return false // [MQTT-3.2.2-3]
}
@@ -447,7 +474,6 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sub.Filter, sub)
s.publishRetainedToClient(cl, sub, existed)
}
return true // [MQTT-3.2.2-3]
@@ -469,6 +495,12 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro
}
if reason.Code >= packets.ErrUnspecifiedError.Code {
if cl.Properties.ProtocolVersion < 5 {
if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes
reason = v3reason
}
}
properties.ReasonString = reason.Reason
ack := packets.Packet{
FixedHeader: packets.FixedHeader{
@@ -591,6 +623,24 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
})
}
// Publish publishes a publish packet into the broker as if it were sent from the speicfied client.
// This is a convenience function which wraps InjectPacket. As such, this method can publish packets
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
cl := s.NewClient(nil, "local", "inline", true)
return s.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: qos,
Retain: retain,
},
TopicName: topic,
Payload: payload,
PacketID: uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks.
})
}
// InjectPacket injects a packet into the broker as if it were sent from the specified client.
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
@@ -611,7 +661,7 @@ func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
// processPublish processes a Publish packet.
func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if !IsValidFilter(pk.TopicName, true) && !cl.Net.Inline {
if !cl.Net.Inline && !IsValidFilter(pk.TopicName, true) {
return nil
}
@@ -619,20 +669,22 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8]
}
if !s.hooks.OnACLCheck(cl, pk.TopicName, true) && !cl.Net.Inline {
if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) {
return nil
}
pk.Origin = cl.ID
pk.Created = time.Now().Unix()
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
return cl.WritePacket(ack)
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
atomic.AddInt64(&s.Info.Inflight, -1)
if !cl.Net.Inline {
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
return cl.WritePacket(ack)
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
atomic.AddInt64(&s.Info.Inflight, -1)
}
}
}
@@ -659,6 +711,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
s.publishToSubscribers(pk)
})
s.hooks.OnPublished(cl, pk)
return nil
}
@@ -689,8 +742,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
s.publishToSubscribers(pk)
})
s.hooks.OnPublish(cl, pk)
s.hooks.OnPublished(cl, pk)
return nil
}
@@ -733,13 +785,13 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
}
}
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (out packets.Packet, err error) {
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
if sub.NoLocal && pk.Origin == cl.ID {
return pk, nil // [MQTT-3.8.3-3]
}
out = pk.Copy(false)
if !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
out := pk.Copy(false)
if cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
out.FixedHeader.Retain = false // [MQTT-3.3.1-12]
}
@@ -788,7 +840,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
}
}
if cl.Net.conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Net.Conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
return pk, packets.CodeDisconnect
}
@@ -924,7 +976,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
code = packets.ErrPacketIdentifierInUse
}
existed := false
filterExisted := make([]bool, len(pk.Filters))
reasonCodes := make([]byte, len(pk.Filters))
for i, sub := range pk.Filters {
if code != packets.CodeSuccess {
@@ -940,8 +992,8 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
} else {
existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if !existed {
isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if isNew {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]
@@ -950,6 +1002,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
}
filterExisted[i] = !isNew
reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
}
@@ -984,7 +1037,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
continue
}
s.publishRetainedToClient(cl, sub, existed)
s.publishRetainedToClient(cl, sub, filterExisted[i])
}
return nil
@@ -1034,14 +1087,20 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
return cl.WritePacket(ack)
}
// unsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) unsubscribeClient(cl *Client) {
for k := range cl.State.Subscriptions.GetAll() {
// UnsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) UnsubscribeClient(cl *Client) {
i := 0
filterMap := cl.State.Subscriptions.GetAll()
filters := make([]packets.Subscription, len(filterMap))
for k, v := range filterMap {
cl.State.Subscriptions.Delete(k)
if s.Topics.Unsubscribe(k, cl.ID) {
atomic.AddInt64(&s.Info.Subscriptions, -1)
}
filters[i] = v
i++
}
s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
}
// processAuth processes an Auth packet.
@@ -1086,9 +1145,14 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {
out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1]
}
// We already have a code we are using to disconnect the client, so we are not
// interested if the write packet fails due to a closed connection (as we are closing it).
err := cl.WritePacket(out)
if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
cl.Stop(code)
if code.Code >= packets.ErrUnspecifiedError.Code {
return code
}
}
return err
@@ -1303,9 +1367,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) {
// loadClients restores clients from the datastore.
func (s *Server) loadClients(v []storage.Client) {
for _, c := range v {
cl := newClientStub()
cl.ID = c.ID
cl.Net.Listener = c.Listener
cl := s.NewClient(nil, c.Listener, c.ID, false)
cl.Properties.Username = c.Username
cl.Properties.Clean = c.Clean
cl.Properties.ProtocolVersion = c.ProtocolVersion
@@ -1412,8 +1474,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) {
// clearExpiredInflights deletes any inflight messages which have expired.
func (s *Server) clearExpiredInflights(now int64) {
for _, client := range s.Clients.GetAll() {
if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 {
s.hooks.OnExpireInflights(client, now)
if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
for _, id := range deleted {
s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package system
// Info contains atomic counters and values for various server statistics

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (