Compare commits

...

91 Commits

Author SHA1 Message Date
mochi-co
ef5dcf68d0 Use context to signal client open state 2023-05-06 11:49:02 +01:00
JB
6704cf7227 Add packet ID exhausted hook (#217) 2023-05-06 10:37:27 +01:00
thedevop
9233e6fd39 Expire session if SessionExpiryInterval is 0 (#216)
If SessionExpiryInterval was not set in CONNECT, SessionExpiryIntervalFlag is also not set. According to spec:
  If the Session Expiry Interval is absent the value 0 is used. If it is set to 0, or is absent, the Session ends when the Network Connection is closed.
2023-05-06 10:12:33 +01:00
ħþ
1ca65d9631 Update codes.go (#215)
fix typo
2023-05-06 10:02:25 +01:00
ħþ
33229da885 Update codes.go (#214)
Fix typo
2023-05-06 09:59:50 +01:00
mochi-co
c274d5fd08 Update server version 2023-05-05 00:04:27 +01:00
JB
10e82f41d6 Lock on close outbound (#213) 2023-05-05 00:02:49 +01:00
JB
e6c07b2b78 Add lock to client writes (#212) 2023-05-04 23:17:12 +01:00
JB
eed3ef9606 Add OnPacketIDExhausted hook (#211) 2023-05-04 22:51:40 +01:00
JB
1ec880844d Correctly validate WillProperties (#210)
Co-authored-by: sukvojte <sukvojte@gmail.com>
2023-05-04 22:37:23 +01:00
werbenhu
4b49652a8c Update build.yml (#203) 2023-05-04 18:09:58 +01:00
mochi-co
d46e7b5bcf Protect close of nil outbound channel 2023-04-21 22:09:59 +01:00
mochi-co
17fb7dadbc Protect close of nil outbound channel 2023-04-21 22:00:27 +01:00
werben
ed7fd836e1 #78 storage hook should not execute the relevant code if the client has been reconnected (#198)
* storage hook should not execute the relevant code if the client has been reconnected #78

* add test cases for coverage decrease

add test cases for coverage decrease
2023-04-21 21:52:44 +01:00
mochi-co
605bb93c75 Move msgToPacket to storage.Message.ToPacket 2023-04-21 21:49:49 +01:00
Wind
c73ace2ea0 Simplified code (#195)
Simplify the code for the loadInflight and loadRetained methods.
Adjust the validation order of the processSubscribe method to ensure that it fails quickly if there is an error, since s.hooks. OnACLCheck generally takes a long time.

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-04-21 21:29:03 +01:00
thedevop
aac6d699da Ensure to close client WriteLoop (#193)
* Ensure client WriteLoop is closed

* Ensure to close client WriteLoop
2023-04-21 21:20:46 +01:00
Hubertus Hohl
7bd7bd5087 fix: common subscriptions issued by different clients at the same time may be lost (#186) 2023-03-11 23:17:10 +00:00
mochi-co
655bf9fdb1 Update readme 2023-03-11 23:15:51 +00:00
mochi-co
b188055c7d Update server version 2023-03-11 23:14:51 +00:00
JB
aaf1d9d4c6 Configurable client bufio reader/writer sizes (#190) 2023-03-11 23:13:28 +00:00
mochi-co
44ce819318 Update server version 2023-02-28 20:58:05 +00:00
dependabot[bot]
e4c76cc60c Bump golang.org/x/net from 0.0.0-20220927171203-f486391704dc to 0.7.0 (#182)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.0.0-20220927171203-f486391704dc to 0.7.0.
- [Release notes](https://github.com/golang/net/releases)
- [Commits](https://github.com/golang/net/commits/v0.7.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-02-28 20:57:33 +00:00
thedevop
da79faa972 Skip expire cleanup for isTakenOver session (#183)
* Skip expire cleanup for isTakenOver session

* Set prev connection to isTakenOver on CleanSession

#173.
2023-02-28 20:53:26 +00:00
JB
46babc89c8 Allow 0 byte usernames if correctly formed (#181)
* Allow 0 byte usernames if correctly formed

* Allow 0 byte usernames if correctly formed
2023-02-25 01:37:54 +00:00
JB
9b7a943888 Correctly identify and clean taken-over sessions (#180) 2023-02-25 01:24:17 +00:00
mochi-co
a909d30923 Small style fix 2023-02-22 23:33:49 +00:00
mochi-co
0851b09e4d Update server version 2023-02-22 23:06:22 +00:00
thedevop
a302c9dd88 Use *packets.Packet for outbound chan (#176)
* Use *packets.Packet for outbound chan

* Use *packets.Packet for outbound chan

* Use *packets.Packet for outbound chan
2023-02-22 23:02:44 +00:00
mochi-co
1e8f922102 update server version 2023-02-20 18:14:57 +00:00
Hubertus Hohl
4c16e5593f fix: correct decoding of packets including Properties exceeding 127 bytes in length (#172) 2023-02-20 18:14:19 +00:00
mochi-co
49cada4fbc Update server version 2023-02-10 23:39:27 +00:00
JB
ef34510c0b Expose dropped publish messages count in sys info (#170) 2023-02-10 23:38:20 +00:00
JB
e5716caad1 Fix potential NextPacketID endless loop, expand tests (#169)
* Fix possible NextPacketID endless loop, expand tests

* Optimize NextPacketID

* Use math constants
2023-02-10 23:27:21 +00:00
thedevop
4b039cb35c Add PublishDropped metrics (#167)
* Add PublishDropped

* Add PublishDropped

* Add PublishDropped

* Update storage_test.go

* Update system.go

* Update server.go
2023-02-10 14:44:01 +00:00
JB
aac245441a No longer issue retained messages on session takeover (#166) 2023-02-09 23:57:24 +00:00
JB
bb54cc68e6 Client write buffers (#165)
* Replace fanpool with client write buffers
2023-02-09 22:34:30 +00:00
thedevop
7ba1352a60 Add Clone to system.Info (#163)
* Add Clone using atomic operations

* Add Clone using atomic operations

* Use sysinfo.Clone

* Unit test for Clone

* Add Clone using atomic operations

* Update

* Update
2023-02-09 19:07:17 +00:00
mochi-co
ca849131eb Update server version 2023-02-05 11:07:07 +00:00
Wind
ba7e534122 failed to delete inflight data (#162)
The s.hooks.OnQosPublish method needs to be called, otherwise the following s.hooks.OnQosComplete or processPuback(s.hooks.OnQosComplete) method will report a data not found error.
2023-02-05 10:53:49 +00:00
mochi-co
db760c34a5 Update server version 2023-02-04 10:57:27 +00:00
JB
ae3ee81bb4 Rename Quota methods for clarity (#159) 2023-02-04 10:53:45 +00:00
JB
c2ca02d149 Move refreshDeadline to only trigger on successful transmission (#157) 2023-02-04 10:16:05 +00:00
Jeroen Rinzema
77a64d9c87 Include a listener accepting an existing net.Listener (#155) 2023-02-04 10:10:10 +00:00
Wind
8dec9cc962 invalid config type provided (#152)
* invalid config type provided

examples/persistence/bolt/main.go: invalid config type provided

* fixed ErrReceiveMaximum(receive maximum exceeded)

No quotas of the inflight is set in the readStore method, so each quota is equal to 0. The inheritClientSession method overrides the quotas of the new client inflight, so the processPublish method reports an ErrReceiveMaximum and disconnects the client.

* reset receive quota

receive quota should be reset across connections (as specified in the spec).
2023-02-04 10:06:26 +00:00
mochi-co
f90e52328d Update server version 2023-01-16 20:08:55 +00:00
JB
50aae47618 Publish retained messages only after connack (#147) 2023-01-16 19:50:01 +00:00
JB
0d79f2d63b Use Atomic instead of RWMutex for Hooks concurrency (#148)
* Use Atomic instead of RWMutex for Hooks concurrency
* Lock Hooks on Add Hook
2023-01-16 19:49:36 +00:00
JB
300152413c Ignore retain as published v3 (#142)
* Optimise Capabilities struct alignment

* Only use RetainAsPublished for v5 clients
2023-01-13 23:38:49 +00:00
mochi-co
0de1d731db Update version number 2023-01-10 00:01:21 +00:00
JB
80746abc52 Use correct connack return codes for MQTTv3 (#140) 2023-01-10 00:00:43 +00:00
mochi-co
a73cf4ca0e Update server version 2023-01-09 23:08:49 +00:00
mochi-co
bc549ee7ed Fix example imports 2023-01-09 22:52:24 +00:00
mochi-co
c464b46713 export client.Net.Conn for external use 2023-01-09 22:49:40 +00:00
mochi-co
05ce56008c Small code improvements 2023-01-09 22:49:20 +00:00
JB
8254cb0cbc Make hooks safe for concurrency (#139)
Co-authored-by: thedevop <60499013+thedevop@users.noreply.github.com>
2023-01-09 22:41:44 +00:00
mochi-co
4ae58b79e3 Update server version 2023-01-07 20:13:48 +00:00
thedevop
b895d688e0 Change inline check order (#133) 2023-01-07 20:02:05 +00:00
mochi-co
a600cd4ead fix grammar on Closed method doc 2023-01-07 17:57:04 +00:00
mochi-co
cdb44990cf Update version number 2023-01-07 17:30:58 +00:00
mochi-co
2d9c128111 Refactor stored subscription value assignments 2023-01-07 17:30:30 +00:00
mochi-co
a0d5bdb39f Fix Typos 2023-01-07 17:30:01 +00:00
Wind
4ebcef3cb6 Save subscription properties for mqttv5 (#131)
* Update redis.go

Save the subscription properties for mqqtv5

* Update badger.go

Save the subscription properties for mqqtv5

* Update bolt.go

Save the subscription properties for mqqtv5

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-01-07 17:24:23 +00:00
thedevop
fb8d4720d7 Add Client Closed (#130)
* Add Client Closed
* Add Client Closed
* Update clients_test.go
2023-01-07 17:14:51 +00:00
JB
4080c89127 Update README.md 2022-12-21 21:00:53 +00:00
JB
1b67e6f3f6 Update README.md 2022-12-21 20:58:25 +00:00
mochi-co
1adb02e087 Update readme and server version 2022-12-21 20:47:58 +00:00
JB
4d4140aa99 Connect ReturnResponseInfo only applies to Connack values (#128) 2022-12-21 20:37:08 +00:00
JB
e31840a37d Optimize inflight expiry (#127)
* Small formatting/style changes for filter existed

* Use OnQoSDropped hook instead of onInflightExpired
2022-12-21 19:44:25 +00:00
JB
7d2e16f2d3 Merge pull request #123 from wind-c/master
Variable existed in the method processSubscribe is unstable
2022-12-21 11:41:14 +00:00
JB
92cd935a16 Merge branch 'master' into master 2022-12-21 11:38:28 +00:00
JB
25ce27ce2d Merge pull request #124 from zgwit/master
Add unix socket listener
2022-12-21 11:28:23 +00:00
jason
527d084a4b Add unix socket listener 2022-12-20 23:02:59 +08:00
Wind
bb9f937bb0 Variable existed in the method processSubscribe is unstable
The variable existed can be changed repeatedly within a for loop. An array variable must be used to record the subscription of each filter.
2022-12-18 13:46:06 +08:00
Wind
511fe88684 Merge branch 'mochi-co:master' into master 2022-12-17 12:33:09 +08:00
JB
75504ff201 Update server version 2022-12-16 18:27:29 +00:00
Wind
a556feb325 Add the OnUnsubscribed hook to the unsubscribeClient method (#122)
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-16 18:23:58 +00:00
“Wind”
d06f47f4b9 Add the OnUnsubscribed hook to the unsubscribeClient method
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-17 00:40:06 +08:00
JB
8d4cc091b4 Update version number 2022-12-16 00:31:59 +00:00
JB
d8f28cb843 Enforce server max packet (#121)
* Enforce Server Maximum Packet Size on client read
* Fix tests
2022-12-16 00:30:23 +00:00
JB
88861c219d Merge pull request #116 from tommyminds/bugfix/ws_malformed_package
Fix websocket malformed packet bug
2022-12-15 18:21:53 +00:00
JB
7ba6cf28d9 Merge branch 'master' into bugfix/ws_malformed_package 2022-12-15 18:21:33 +00:00
JB
c174cfdc6b Merge pull request #119 from mochi-co/fix-on-published
Fix mis-typed onpublished hook, update version, fanpool defaults
2022-12-15 18:21:19 +00:00
mochi-co
4f198a99dd Fix mis-typed onpublished hook, update version, fanpool defaults 2022-12-15 18:19:02 +00:00
Tommy Maintz
2a9c9fcc40 Fix websocket malformed packet bug 2022-12-14 21:41:33 +01:00
JB
835a85c8bf Update README.md 2022-12-12 11:44:36 +00:00
mochi-co
fe5d9ffa61 Simplify Client construction, add NewClient method to Server, add Publish convenience method 2022-12-12 11:37:19 +00:00
mochi-co
aac186dcc1 Add newline for godoc formatting 2022-12-11 22:25:21 +00:00
JB
42931f332f Update badges to use v2 references 2022-12-11 21:44:44 +00:00
mochi-co
8a04648c09 Cleanup godoc formatting 2022-12-11 21:38:01 +00:00
JB
854c033fb6 Update README.md 2022-12-11 12:21:25 +00:00
264 changed files with 23928 additions and 3686 deletions

View File

@@ -7,37 +7,39 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: 1.19
- name: Vet
run: go vet ./...
- name: Test
run: go test -race ./... && echo true
coverage:
name: Test with Coverage
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.19.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.19'
- name: Check out code
uses: actions/checkout@v3
- name: Install dependencies
run: |
go mod download
- name: Run Unit tests
run: |
go test -race -covermode atomic -coverprofile=covprofile ./...
- name: Install goveralls
run: go install github.com/mattn/goveralls@latest
- name: Send coverage
env:
COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: goveralls -coverprofile=covprofile -service=github

View File

@@ -2,9 +2,9 @@
<p align="center">
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master&v2)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
</p>
@@ -19,6 +19,11 @@ MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It
## What's new in Version 2.0.0?
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
Don't forget to use the new v2 import paths:
```go
import "github.com/mochi-co/mqtt/v2"
```
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
- User and MQTTv5 Packet Properties
- Topic Aliases
@@ -36,10 +41,10 @@ Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learn
- Direct Packet Injection using special inline client, or masquerade as existing clients.
- Performant and Stable:
- Our classic trie-based Topic-Subscription model.
- A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
- Over a thousand carefully considered unit test scenarios.
- TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners.
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
@@ -78,22 +83,26 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
Importing Mochi MQTT as a package requires just a few lines of code to get started.
``` go
import (
"log"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
// Create the new MQTT Server.
server := mqtt.New(nil)
// Allow all connections.
_ = server.AddHook(new(auth.AllowHook), nil)
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.Serve()
if err != nil {
@@ -107,10 +116,15 @@ Examples of running the broker with various configurations can be found in the [
#### Network Listeners
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
- `listeners.NewTCP(...)` - A TCP listener.
- `listeners.NewWebsocket(...)` A Websocket listener.
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
| Listener | Usage |
| --- | --- |
| listeners.NewTCP | A TCP listener |
| listeners.NewUnixSock | A Unix Socket listener |
| listeners.NewNet | A net.Listener listener |
| listeners.NewWebsocket | A Websocket listener |
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
A `*listeners.Config` may be passed to configure TLS.
@@ -122,6 +136,8 @@ A number of configurable options are available which can be used to alter the be
```go
server := mqtt.New(&mqtt.Options{
Capabilities: mqtt.Capabilities{
ClientNetWriteBufferSize: 4096,
ClientNetReadBufferSize: 4096,
MaximumSessionExpiryInterval: 3600,
Compatibilities: mqtt.Compatibilities{
ObscureNotAuthorized: true,
@@ -131,7 +147,7 @@ server := mqtt.New(&mqtt.Options{
})
```
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options.
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
## Event Hooks
@@ -283,15 +299,16 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
| OnRetainMessage | Called then a published message is retained. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnPacketIDExhausted | Called when a client runs out of unused packet ids to assign. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
@@ -300,13 +317,22 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
### Packet Injection
It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement).
### Direct Publish
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
```go
cl := mqtt.NewInlineClient("inline", "local")
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
```
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
### Packet Injection
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
```go
cl := server.NewClient(nil, "local", "inline", true)
server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
@@ -320,6 +346,8 @@ server.InjectPacket(cl, packets.Packet{
See the [hooks example](examples/hooks/main.go) to see this feature in action.
### Testing
#### Unit Tests
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
@@ -344,23 +372,23 @@ Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inov
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 |
| Mochi v2.2.0 | 127,216 | 125,748 | 124,279 | 319,250 | 309,327 | 299,405 |
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 |
| Mochi v2.2.0 | 45,615 | 30,129 | 21,138 | 232,717 | 86,323 | 50,402 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
Million Message Challenge (hit the server with 1 million messages immediately):
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636 |
| Mochi v2.2.0 | 51,044 | 4,682 | 2,345 | 72,634 | 7,645 | 2,464 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |

View File

@@ -1,11 +1,13 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
@@ -86,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
if client.Net.Listener == id && !client.Closed() {
clients = append(clients, client)
}
}
@@ -105,7 +107,7 @@ type Client struct {
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
conn net.Conn // the net.Conn used to establish the connection
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
@@ -134,30 +136,31 @@ type Will struct {
// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
endOnce sync.Once // only end once
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
keepalive uint16 // the number of seconds the connection can wait
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, o *ops) *Client {
// newClient returns a new instance of Client. This is almost exclusively used by Server
// for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
cl := &Client{
Net: ClientConnection{
conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Remote: c.RemoteAddr().String(),
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
open: context.Background(),
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
@@ -165,43 +168,29 @@ func NewClient(c net.Conn, o *ops) *Client {
ops: o,
}
if c != nil {
cl.Net = ClientConnection{
Conn: c,
bconn: bufio.NewReadWriter(
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize),
),
Remote: c.RemoteAddr().String(),
}
}
cl.refreshDeadline(cl.State.keepalive)
return cl
}
// NewInlineClient returns a client used when publishing from the embedding system.
func NewInlineClient(id, remote string) *Client {
return &Client{
ID: id,
Net: ClientConnection{
Remote: remote,
Inline: true,
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
}
}
// newClientStub returns an instance of Client with minimal initializations, such as
// restoring client data from a db. In particular, the client is marked as offline (done).
func newClientStub() *Client {
return &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
done: 1,
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for pk := range cl.State.outbound {
if err := cl.WritePacket(*pk); err != nil {
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
}
atomic.AddInt32(&cl.State.outboundQty, -1)
}
}
@@ -214,8 +203,8 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
@@ -225,7 +214,7 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Props.AssignedClientID = cl.ID
}
cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive
cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}
@@ -253,13 +242,13 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.Net.conn != nil {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
if cl.Net.Conn != nil {
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
@@ -267,28 +256,30 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i + 1
started := i
overflowed := false
for {
if i >= 65535 {
overflowed = true
i = 1
} else {
i++
}
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
if i >= cl.ops.options.Capabilities.maximumPacketID {
overflowed = true
i = 0
continue
}
i++
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
break
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
}
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
@@ -302,7 +293,7 @@ func (cl *Client) ResendInflightMessages(force bool) error {
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
}
// cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends)
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
err := cl.WritePacket(tk)
if err != nil {
return err
@@ -320,17 +311,18 @@ func (cl *Client) ResendInflightMessages(force bool) error {
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
var deleted int64
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted++
deleted = append(deleted, tk.PacketID)
}
}
}
return deleted
}
@@ -340,7 +332,7 @@ func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return nil
}
@@ -365,20 +357,28 @@ func (cl *Client) Read(packetHandler ReadFn) error {
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.done) == 1 {
return
}
cl.State.endOnce.Do(func() {
if cl.Net.conn != nil {
_ = cl.Net.conn.Close() // omit close error
cl.Lock()
defer cl.Unlock()
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
if err != nil {
cl.State.stopCause.Store(err)
}
atomic.StoreUint32(&cl.State.done, 1)
if cl.State.outbound != nil {
close(cl.State.outbound)
}
if cl.State.open != nil {
var cancel context.CancelFunc
cl.State.open, cancel = context.WithCancel(cl.State.open)
cancel()
}
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
@@ -391,6 +391,11 @@ func (cl *Client) StopCause() error {
return cl.State.stopCause.Load().(error)
}
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return cl.State.open == nil || cl.State.open.Err() != nil
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
@@ -413,6 +418,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
return err
}
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil
}
@@ -480,15 +489,6 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) error {
if atomic.LoadUint32(&cl.State.done) == 1 {
return ErrConnectionClosed
}
if cl.Net.conn == nil {
return nil
}
defer cl.refreshDeadline(cl.State.keepalive)
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
@@ -502,8 +502,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
@@ -552,8 +552,19 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
}
cl.Lock()
defer cl.Unlock()
if cl.Closed() {
return ErrConnectionClosed
}
if cl.Net.Conn == nil {
return nil
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.conn)
n, err := nb.WriteTo(cl.Net.Conn)
if err != nil {
return err
}

View File

@@ -1,9 +1,11 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"context"
"errors"
"io"
"net"
@@ -21,16 +23,20 @@ const pkInfo = "packet type %v, %s"
var errClientStop = errors.New("test stop")
func newClient() (cl *Client, r net.Conn, w net.Conn) {
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = NewClient(w, &ops{
cl = newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
options: &Options{
Capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
},
})
@@ -41,6 +47,9 @@ func newClient() (cl *Client, r net.Conn, w net.Conn) {
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
go cl.WriteLoop()
return
}
@@ -106,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
@@ -118,34 +127,23 @@ func TestClientsGetByListener(t *testing.T) {
}
func TestNewClient(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Nil(t, cl.StopCause())
}
func TestNewClientStub(t *testing.T) {
cl := newClientStub()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
}
func TestNewInlineClient(t *testing.T) {
cl := NewInlineClient("inline", "local")
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done))
require.Equal(t, "inline", cl.ID)
require.Equal(t, "local", cl.Net.Remote)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.NotNil(t, cl.ops.options.Capabilities)
require.False(t, cl.Net.Inline)
}
func TestClientParseConnect(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
@@ -175,14 +173,14 @@ func TestClientParseConnect(t *testing.T) {
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
}
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
@@ -207,13 +205,13 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) {
}
func TestClientParseConnectNoID(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.ParseConnect("tcp1", packets.Packet{})
require.NotEmpty(t, cl.ID)
}
func TestClientNextPacketID(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
i, err := cl.NextPacketID()
require.NoError(t, err)
@@ -225,7 +223,7 @@ func TestClientNextPacketID(t *testing.T) {
}
func TestClientNextPacketIDInUse(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
// skip over 2
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
@@ -248,33 +246,37 @@ func TestClientNextPacketIDInUse(t *testing.T) {
}
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newClient()
for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
cl, _, _ := newTestClient()
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
}
i, err := cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
require.Equal(t, uint32(0), i)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
_, err = cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newClient()
cl.State.packetID = uint32(65534)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(65535), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
}
func TestClientClearInflights(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
@@ -284,13 +286,15 @@ func TestClientClearInflights(t *testing.T) {
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
cl.ClearInflights(n, 4)
deleted := cl.ClearInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
require.Equal(t, 2, cl.State.Inflight.Len())
}
func TestClientResendInflightMessages(t *testing.T) {
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
cl, r, w := newClient()
cl, r, w := newTestClient()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
@@ -310,7 +314,7 @@ func TestClientResendInflightMessages(t *testing.T) {
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newClient()
cl, r, _ := newTestClient()
r.Close()
cl.State.Inflight.Set(*pk1.Packet)
@@ -322,19 +326,19 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
}
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
}
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
@@ -349,7 +353,7 @@ func TestClientReadFixedHeader(t *testing.T) {
}
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
@@ -362,8 +366,24 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
require.Error(t, err)
}
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.options.Capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
@@ -377,7 +397,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
}
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
@@ -391,7 +411,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
}
func TestClientReadOK(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -445,9 +465,9 @@ func TestClientReadOK(t *testing.T) {
}
func TestClientReadDone(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
cl.State.open = nil
o := make(chan error)
go func() {
@@ -460,15 +480,23 @@ func TestClientReadDone(t *testing.T) {
}
func TestClientStop(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.True(t, cl.Closed())
require.Equal(t, nil, cl.StopCause())
}
func TestClientClosed(t *testing.T) {
cl, _, _ := newTestClient()
require.False(t, cl.Closed())
cl.Stop(nil)
require.True(t, cl.Closed())
}
func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -485,7 +513,7 @@ func TestClientReadFixedHeaderError(t *testing.T) {
}
func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -505,7 +533,7 @@ func TestClientReadReadHandlerErr(t *testing.T) {
}
func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -537,7 +565,7 @@ func TestClientReadReadPacketOK(t *testing.T) {
}
func TestClientReadPacket(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
for _, tx := range pkTable {
@@ -570,9 +598,17 @@ func TestClientReadPacket(t *testing.T) {
}
}
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
}
func TestClientWritePacket(t *testing.T) {
for _, tt := range pkTable {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
@@ -588,7 +624,7 @@ func TestClientWritePacket(t *testing.T) {
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.conn.Close()
cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
@@ -612,7 +648,7 @@ func TestClientWritePacket(t *testing.T) {
}
func TestWriteClientOversizePacket(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Properties.Props.MaximumPacketSize = 2
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
err := cl.WritePacket(pk)
@@ -621,7 +657,7 @@ func TestWriteClientOversizePacket(t *testing.T) {
}
func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -641,7 +677,7 @@ func TestClientReadPacketReadingError(t *testing.T) {
}
func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
@@ -660,7 +696,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
}
func TestClientWritePacketWriteNoConn(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Stop(errClientStop)
err := cl.WritePacket(*pkTable[1].Packet)
@@ -669,15 +705,15 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
}
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newClient()
cl.Net.conn.Close()
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
}
func TestClientWritePacketInvalidPacket(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
err := cl.WritePacket(packets.Packet{})
require.Error(t, err)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -35,7 +36,7 @@ func main() {
}
err = server.AddHook(new(debug.Hook), &debug.Options{
ShowPacketData: true,
// ShowPacketData: true,
})
if err != nil {
log.Fatal(err)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -51,15 +52,30 @@ func main() {
// `server.Publish` method. Subscribe to `direct/publish` using your
// MQTT client to see the messages.
go func() {
cl := mqtt.NewInlineClient("inline", "local")
for range time.Tick(time.Second * 10) {
server.InjectPacket(cl, packets.Packet{
cl := server.NewClient(nil, "local", "inline", true)
for range time.Tick(time.Second * 1) {
err := server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "direct/publish",
Payload: []byte("scheduled message"),
Payload: []byte("injected scheduled message"),
})
if err != nil {
server.Log.Error().Err(err).Msg("server.InjectPacket")
}
server.Log.Info().Msgf("main.go injected packet to direct/publish")
}
}()
// There is also a shorthand convenience function, Publish, for easily sending
// publish packets if you are not concerned with creating your own packets.
go func() {
for range time.Tick(time.Second * 5) {
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
if err != nil {
server.Log.Error().Err(err).Msg("server.Publish")
}
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
}
}()

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -28,7 +29,6 @@ func main() {
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
_ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -29,12 +30,15 @@ func main() {
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(bolt.Hook), bolt.Options{
err := server.AddHook(new(bolt.Hook), &bolt.Options{
Path: "bolt.db",
Options: &bbolt.Options{
Timeout: 500 * time.Millisecond,
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (

View File

@@ -1,100 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
package mqtt
import (
"sync"
"sync/atomic"
xh "github.com/cespare/xxhash/v2"
)
// taskChan is a channel for incoming task functions.
type taskChan chan func()
// FanPool is a fixed-sized fan-style worker pool with multiple
// working 'columns'. Instead of a single queue channel processed by
// many goroutines, this fan pool uses many queue channels each
// processed by a single goroutine.
// Very special thanks are given to the authors of HMQ in particular
// @chowyu08 and @muXxer for their work on the fixpool worker pool
// https://github.com/fhmq/hmq/blob/master/pool/fixpool.go
// from which this fan-pool is heavily inspired.
type FanPool struct {
queue []taskChan
wg sync.WaitGroup
capacity uint64
perChan uint64
Mutex sync.Mutex
}
// New returns a new instance of FanPool. fanSize controls the number of 'columns'
// of the fan, whereas queueSize controls the size of each column's queue.
func NewFanPool(fanSize, queueSize uint64) *FanPool {
pool := &FanPool{
capacity: fanSize,
perChan: queueSize,
queue: make([]taskChan, fanSize),
}
pool.fillWorkers(fanSize)
return pool
}
// fillWorkers adds columns to the fan pool with an associated worker goroutine.
func (p *FanPool) fillWorkers(n uint64) {
for i := uint64(0); i < n; i++ {
p.queue[i] = make(taskChan, p.perChan)
go p.worker(p.queue[i])
p.wg.Add(1)
}
}
// worker is a worker goroutine which processes tasks from a single queue.
func (p *FanPool) worker(ch taskChan) {
defer p.wg.Done()
var task func()
var ok bool
for {
task, ok = <-ch
if !ok {
return
}
task()
}
}
// Enqueue adds a new task to the queue to be processed.
func (p *FanPool) Enqueue(id string, task func()) {
if p.Size() == 0 {
return
}
// We can use xh.Sum64 to get a specific queue index
// which remains the same for a client id, giving each
// client their own queue.
p.queue[xh.Sum64([]byte(id))%p.Size()] <- task
}
// Wait blocks until all the workers in the pool have completed.
func (p *FanPool) Wait() {
p.wg.Wait()
}
// Close issues a shutdown signal to the workers.
func (p *FanPool) Close() {
for i := 0; i < int(p.Size()); i++ {
if p.queue[i] != nil {
close(p.queue[i])
}
}
p.queue = nil
atomic.StoreUint64(&p.capacity, 0)
}
// Size returns the current number of workers in the pool.
func (p *FanPool) Size() uint64 {
return atomic.LoadUint64(&p.capacity)
}

View File

@@ -1,88 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestFanPool(t *testing.T) {
f := NewFanPool(1, 2)
require.NotNil(t, f)
require.Equal(t, uint64(1), f.capacity)
require.Equal(t, 2, cap(f.queue[0]))
o := make(chan bool)
go func() {
f.Enqueue("test", func() {
o <- true
})
}()
require.True(t, <-o)
f.Close()
f.Wait()
}
func TestFillWorkers(t *testing.T) {
f := &FanPool{
perChan: 3,
queue: make([]taskChan, 2),
}
f.fillWorkers(2)
require.Len(t, f.queue, 2)
require.Equal(t, 3, cap(f.queue[0]))
}
func TestEnqueue(t *testing.T) {
f := &FanPool{
capacity: 2,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
},
}
go func() {
f.Enqueue("a", func() {})
}()
require.NotNil(t, <-f.queue[1])
}
func TestEnqueueOnEmpty(t *testing.T) {
f := &FanPool{
queue: []taskChan{},
}
go func() {
f.Enqueue("a", func() {})
}()
require.Len(t, f.queue, 0)
}
func TestSize(t *testing.T) {
f := &FanPool{
capacity: 10,
}
require.Equal(t, uint64(10), f.Size())
}
func TestClose(t *testing.T) {
f := &FanPool{
capacity: 3,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
make(taskChan, 2),
},
}
f.Close()
require.Equal(t, uint64(0), f.Size())
require.Nil(t, f.queue)
}

6
go.mod
View File

@@ -6,7 +6,6 @@ require (
github.com/alicebob/miniredis/v2 v2.23.0
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/cespare/xxhash/v2 v2.1.2
github.com/go-redis/redis/v8 v8.11.5
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
@@ -21,6 +20,7 @@ require (
require (
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/badger v1.6.0 // indirect
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
@@ -33,8 +33,8 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
golang.org/x/net v0.0.0-20220927171203-f486391704dc // indirect
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sys v0.5.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.28.1 // indirect
)

10
go.sum
View File

@@ -109,8 +109,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20220927171203-f486391704dc h1:FxpXZdoBqT8RjqTy6i1E8nXHhW21wK7ptQ/EPIGxzPQ=
golang.org/x/net v0.0.0-20220927171203-f486391704dc/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -118,11 +118,11 @@ golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=

152
hooks.go
View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// SPDX-FileContributor: mochi-co, thedevop
package mqtt
import (
@@ -38,15 +39,16 @@ const (
OnUnsubscribed
OnPublish
OnPublished
OnPublishDropped
OnRetainMessage
OnQosPublish
OnQosComplete
OnQosDropped
OnPacketIDExhausted
OnWill
OnWillSent
OnClientExpired
OnRetainedExpired
OnExpireInflights
StoredClients
StoredSubscriptions
StoredInflightMessages
@@ -87,15 +89,16 @@ type Hook interface {
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnPublishDropped(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
OnQosDropped(cl *Client, pk packets.Packet)
OnPacketIDExhausted(cl *Client, pk packets.Packet)
OnWill(cl *Client, will Will) (Will, error)
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
OnExpireInflights(cl *Client, expiry int64)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
@@ -111,10 +114,10 @@ type HookOptions struct {
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal []Hook // a slice of hooks
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
@@ -124,7 +127,7 @@ func (h *Hooks) Len() int64 {
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
for _, hb := range b {
if hook.Provides(hb) {
return true
@@ -139,26 +142,39 @@ func (h *Hooks) Provides(b ...byte) bool {
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
if h.internal == nil {
h.internal = []Hook{}
}
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
h.internal = append(h.internal, hook)
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
if err := hook.Stop(); err != nil {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
@@ -173,7 +189,7 @@ func (h *Hooks) Stop() {
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
@@ -182,7 +198,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) {
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
@@ -191,7 +207,7 @@ func (h *Hooks) OnStarted() {
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
@@ -200,7 +216,7 @@ func (h *Hooks) OnStopped() {
// OnConnect is called when a new client connects.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnect) {
hook.OnConnect(cl, pk)
}
@@ -209,7 +225,7 @@ func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
@@ -218,7 +234,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
@@ -228,7 +244,7 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
@@ -249,7 +265,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
@@ -265,7 +281,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet,
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
@@ -276,7 +292,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
@@ -286,7 +302,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
@@ -298,7 +314,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
@@ -308,7 +324,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
@@ -320,7 +336,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
@@ -333,7 +349,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
@@ -343,19 +359,19 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
}
}
// OnPublish is called when a client publishes a message. This method differs from OnMessage
// OnPublish is called when a client publishes a message. This method differs from OnPublished
// in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
@@ -374,16 +390,26 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
}
}
// OnPublishDropped is called when a message to a client was dropped instead of delivered
// such as when a client is too slow to respond.
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublishDropped) {
hook.OnPublishDropped(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
@@ -394,7 +420,7 @@ func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
@@ -405,7 +431,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
@@ -413,22 +439,32 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned.
// It is typically used to delete an inflight message from a store.
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
}
}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to
// assign to a packet.
func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketIDExhausted) {
hook.OnPacketIDExhausted(cl, pk)
}
}
}
// OnWill is called when a client disconnects and publishes an LWT message. This method
// differs from OnWillSent in that it allows you to modify the LWT message before it is
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
@@ -444,7 +480,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
@@ -453,7 +489,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
@@ -462,7 +498,7 @@ func (h *Hooks) OnClientExpired(cl *Client) {
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
@@ -472,7 +508,7 @@ func (h *Hooks) OnRetainedExpired(filter string) {
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
@@ -492,7 +528,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
@@ -512,7 +548,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
@@ -532,7 +568,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
@@ -551,7 +587,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
@@ -573,7 +609,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
@@ -589,7 +625,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
@@ -600,19 +636,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnExpireInflights is called when the server issues a clear request for expired
// inflight messages. Expiry should be the time after which the message is no longer
// valid (usually some time in the past). A message has expired if it's created time
// is older than time.Now() minus Inflight TTL. This method can be used to expire
// old inflight messages in a persistent store which doesnt support per-item TTL.
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
for _, hook := range h.internal {
if hook.Provides(OnExpireInflights) {
hook.OnExpireInflights(cl, expiry)
}
}
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
@@ -728,6 +751,9 @@ func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, err
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
@@ -740,6 +766,9 @@ func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
// OnQosDropped is called the Qos flow for a message expires.
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
// OnWill is called when a client disconnects and publishes an LWT message.
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
return will, nil
@@ -754,9 +783,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package debug
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
@@ -79,7 +80,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -182,6 +182,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
@@ -198,11 +202,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Upsert(in.ID, in)
@@ -347,32 +355,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.Delete(m.ID, new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.Delete(retainedKey(filter), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
@@ -381,6 +370,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")

View File

@@ -1,16 +1,15 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"errors"
"os"
"strings"
"testing"
"time"
"github.com/asdine/storm/v3"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
@@ -169,6 +168,21 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, badgerhold.ErrNotFound, err)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -218,6 +232,29 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -332,6 +369,21 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -418,48 +470,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -185,6 +184,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
@@ -201,12 +204,17 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
@@ -369,34 +377,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
return
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
}
@@ -404,6 +391,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")

View File

@@ -1,10 +1,10 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package bolt
import (
"errors"
"os"
"testing"
"time"
@@ -211,6 +211,21 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, storm.ErrNotFound, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -226,6 +241,29 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -340,6 +378,21 @@ func TestOnRetainedExpired(t *testing.T) {
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -426,48 +479,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
@@ -82,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -199,6 +199,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
@@ -215,11 +219,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
@@ -363,37 +371,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
if d.Created < expiry || d.Created == 0 {
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
@@ -402,6 +386,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
@@ -252,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, redis.Nil, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -268,6 +285,28 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -391,6 +430,22 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -483,60 +538,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
n := time.Now().Unix()
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
&storage.Message{ID: "i3", T: storage.InflightKey},
).Err()
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var r []storage.Message
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
require.NoError(t, err)
require.Len(t, rows, 1)
for _, row := range rows {
var d storage.Message
err = d.UnmarshalBinary([]byte(row))
require.NoError(t, err)
r = append(r, d)
}
require.Len(t, r, 1)
require.Equal(t, "i1", r[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
@@ -116,6 +117,36 @@ func (d *Message) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, d)
}
// ToPacket converts a storage.Message to a standard packet.
func (d *Message) ToPacket() packets.Packet {
pk := packets.Packet{
FixedHeader: d.FixedHeader,
PacketID: d.PacketID,
TopicName: d.TopicName,
Payload: d.Payload,
Origin: d.Origin,
Created: d.Created,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
}
// Return a deep copy of the packet data otherwise the slices will
// continue pointing at the values from the storage packet.
pk = pk.Copy(true)
pk.FixedHeader.Dup = d.FixedHeader.Dup
return pk
}
// Subscription is a storable representation of an mqtt subscription.
type Subscription struct {
T string `json:"t"`

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
@@ -103,6 +104,7 @@ var (
ClientsMaximum: 7,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
PacketsReceived: 12,
PacketsSent: 13,
Retained: 15,
@@ -110,7 +112,7 @@ var (
InflightDropped: 17,
},
}
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)
func TestClientMarshalBinary(t *testing.T) {
@@ -192,3 +194,35 @@ func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
require.NoError(t, err)
require.Equal(t, SystemInfo{}, d)
}
func TestMessageToPacket(t *testing.T) {
d := messageStruct
pk := d.ToPacket()
require.Equal(t, packets.Packet{
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: d.FixedHeader.Remaining,
Type: d.FixedHeader.Type,
Qos: d.FixedHeader.Qos,
Dup: d.FixedHeader.Dup,
Retain: d.FixedHeader.Retain,
},
Origin: d.Origin,
TopicName: d.TopicName,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
PacketID: 100,
Created: d.Created,
}, pk)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -26,6 +27,10 @@ type modifiedHookBase struct {
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
@@ -177,12 +182,20 @@ func TestHooksProvides(t *testing.T) {
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddAndLen(t *testing.T) {
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {
@@ -223,14 +236,15 @@ func TestHooksNonReturns(t *testing.T) {
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnPublishDropped(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})
h.OnQosDropped(cl, packets.Packet{})
h.OnPacketIDExhausted(cl, packets.Packet{})
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
h.OnExpireInflights(cl, time.Now().Unix()-1)
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -57,6 +58,18 @@ func (i *Inflight) Len() int {
return len(i.internal)
}
// Clone returns a new instance of Inflight with the same message data.
// This is used when transferring inflights from a taken-over session.
func (i *Inflight) Clone() *Inflight {
c := NewInflights()
i.RLock()
defer i.RUnlock()
for k, v := range i.internal {
c.internal[k] = v
}
return c
}
// GetAll returns all the inflight messages.
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
i.RLock()
@@ -103,14 +116,14 @@ func (i *Inflight) Delete(id uint16) bool {
}
// TakeRecieveQuota reduces the receive quota by 1.
func (i *Inflight) TakeReceiveQuota() {
func (i *Inflight) DecreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) > 0 {
atomic.AddInt32(&i.receiveQuota, -1)
}
}
// TakeRecieveQuota increases the receive quota by 1.
func (i *Inflight) ReturnReceiveQuota() {
func (i *Inflight) IncreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
atomic.AddInt32(&i.receiveQuota, 1)
}
@@ -122,15 +135,15 @@ func (i *Inflight) ResetReceiveQuota(n int32) {
atomic.StoreInt32(&i.maximumReceiveQuota, n)
}
// TakeSendQuota reduces the send quota by 1.
func (i *Inflight) TakeSendQuota() {
// DecreaseSendQuota reduces the send quota by 1.
func (i *Inflight) DecreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) > 0 {
atomic.AddInt32(&i.sendQuota, -1)
}
}
// ReturnSendQuota increases the send quota by 1.
func (i *Inflight) ReturnSendQuota() {
// IncreaseSendQuota increases the send quota by 1.
func (i *Inflight) IncreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
atomic.AddInt32(&i.sendQuota, 1)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -12,7 +13,7 @@ import (
)
func TestInflightSet(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.True(t, r)
@@ -24,7 +25,7 @@ func TestInflightSet(t *testing.T) {
}
func TestInflightGet(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
msg, ok := cl.State.Inflight.Get(2)
@@ -33,7 +34,7 @@ func TestInflightGet(t *testing.T) {
}
func TestInflightGetAllAndImmediate(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
@@ -55,13 +56,23 @@ func TestInflightGetAllAndImmediate(t *testing.T) {
}
func TestInflightLen(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightClone(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
cloned := cl.State.Inflight.Clone()
require.NotNil(t, cloned)
require.NotSame(t, cloned, cl.State.Inflight)
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
require.NotNil(t, cl.State.Inflight.internal[3])
@@ -94,12 +105,12 @@ func TestReceiveQuota(t *testing.T) {
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
// Return 1
i.ReturnReceiveQuota()
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Try to go over max limit
i.ReturnReceiveQuota()
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
@@ -109,12 +120,12 @@ func TestReceiveQuota(t *testing.T) {
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
// Take 1
i.TakeReceiveQuota()
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
// Try to go below zero
i.TakeReceiveQuota()
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}
@@ -136,12 +147,12 @@ func TestSendQuota(t *testing.T) {
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
// Return 1
i.ReturnSendQuota()
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Try to go over max limit
i.ReturnSendQuota()
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
@@ -151,18 +162,18 @@ func TestSendQuota(t *testing.T) {
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
// Take 1
i.TakeSendQuota()
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
// Try to go below zero
i.TakeSendQuota()
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}
func TestNextImmediate(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -106,28 +107,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info := &system.Info{
Version: l.sysInfo.Version,
Started: atomic.LoadInt64(&l.sysInfo.Started),
Time: atomic.LoadInt64(&l.sysInfo.Time),
Uptime: atomic.LoadInt64(&l.sysInfo.Uptime),
BytesReceived: atomic.LoadInt64(&l.sysInfo.BytesReceived),
BytesSent: atomic.LoadInt64(&l.sysInfo.BytesSent),
ClientsConnected: atomic.LoadInt64(&l.sysInfo.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&l.sysInfo.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&l.sysInfo.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&l.sysInfo.MessagesReceived),
MessagesSent: atomic.LoadInt64(&l.sysInfo.MessagesSent),
InflightDropped: atomic.LoadInt64(&l.sysInfo.InflightDropped),
Subscriptions: atomic.LoadInt64(&l.sysInfo.Subscriptions),
PacketsReceived: atomic.LoadInt64(&l.sysInfo.PacketsReceived),
PacketsSent: atomic.LoadInt64(&l.sysInfo.PacketsSent),
Retained: atomic.LoadInt64(&l.sysInfo.Retained),
Inflight: atomic.LoadInt64(&l.sysInfo.Inflight),
MemoryAlloc: atomic.LoadInt64(&l.sysInfo.MemoryAlloc),
Threads: atomic.LoadInt64(&l.sysInfo.Threads),
}
info := *l.sysInfo.Clone()
out, err := json.MarshalIndent(info, "", "\t")
if err != nil {

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

88
listeners/net.go Normal file
View File

@@ -0,0 +1,88 @@
package listeners
import (
"net"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// Net is a listener for establishing client connections on basic TCP protocol.
type Net struct { // [MQTT-4.2.0-1]
mu sync.Mutex
listener net.Listener // a net.Listener which will listen for new clients
id string // the internal id of the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
func NewNet(id string, listener net.Listener) *Net {
return &Net{
id: id,
listener: listener,
}
}
// ID returns the id of the listener.
func (l *Net) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Net) Address() string {
return l.listener.Addr().String()
}
// Protocol returns the network of the listener.
func (l *Net) Protocol() string {
return l.listener.Addr().Network()
}
// Init initializes the listener.
func (l *Net) Init(log *zerolog.Logger) error {
l.log = log
return nil
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *Net) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listener.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *Net) Close(closeClients CloseFn) {
l.mu.Lock()
defer l.mu.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listener != nil {
err := l.listener.Close()
if err != nil {
return
}
}
}

105
listeners/net_test.go Normal file
View File

@@ -0,0 +1,105 @@
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewNet(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.id)
}
func TestNetID(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.ID())
}
func TestNetAddress(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, n.Addr().String(), l.Address())
}
func TestNetProtocol(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "tcp", l.Protocol())
}
func TestNetInit(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
}
func TestNetServeAndClose(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestNetEstablishThenEnd(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", n.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

98
listeners/unixsock.go Normal file
View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"net"
"os"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once.
}
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock {
return &UnixSock{
id: id,
address: address,
}
}
// ID returns the id of the listener.
func (l *UnixSock) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *UnixSock) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *UnixSock) Protocol() string {
return "unix"
}
// Init initializes the listener.
func (l *UnixSock) Init(log *zerolog.Logger) error {
l.log = log
var err error
_ = os.Remove(l.address)
l.listen, err = net.Listen("unix", l.address)
return err
}
// Serve starts waiting for new UnixSock connections, and calls the establish
// connection callback for any received.
func (l *UnixSock) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *UnixSock) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testUnixAddr = "mochi.sock"
func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address)
}
func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.ID())
}
func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, testUnixAddr, l.Address())
}
func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "unix", l.Protocol())
}
func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr)
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
}
func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("unix", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@@ -1,11 +1,13 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"errors"
"io"
"net"
"net/http"
"sync"
@@ -136,25 +138,35 @@ type wsConn struct {
}
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
func (ws *wsConn) Read(p []byte) (int, error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
return 0, err
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
return 0, err
}
return r.Read(p)
var n, br int
for {
br, err = r.Read(p[n:])
n += br
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
}
return n, err
}
}
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
func (ws *wsConn) Write(p []byte) (int, error) {
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
return 0, err
}
return len(p), nil

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -375,7 +376,7 @@ func TestEncodeUint16(t *testing.T) {
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result)
result = encodeUint16(65535)
result = encodeUint16(math.MaxUint16)
require.Equal(t, []byte{0xff, 0xff}, result)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
// Code contains a reason code and reason string for a response.
@@ -20,7 +21,7 @@ func (c Code) Error() string {
}
var (
// QosCodes indicicates the reason codes for each Qos byte.
// QosCodes indicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
@@ -112,15 +113,35 @@ var (
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
// MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
Err3ClientIdentifierNotValid = Code{Code: 0x02}
Err3ServerUnavailable = Code{Code: 0x03}
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
Err3NotAuthorized = Code{Code: 0x05}
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
// This is required because MQTTv3 has different return byte specification.
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
V5CodesToV3 = map[Code]Code{
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
ErrServerUnavailable: Err3ServerUnavailable,
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
ErrBadUsernameOrPassword: Err3NotAuthorized,
}
)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,12 +1,14 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"errors"
"fmt"
"math"
"strconv"
"strings"
"sync"
@@ -14,22 +16,23 @@ import (
// All of the valid packet types and their packet identifier.
const (
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Auth // 15
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Auth // 15
WillProperties byte = 99 // Special byte for validating Will Properties.
)
var (
@@ -207,7 +210,10 @@ func (pk *Packet) Copy(allowTransfer bool) Packet {
Created: pk.Created,
Expiry: pk.Expiry,
Origin: pk.Origin,
PacketID: pk.PacketID, // ... ? Packet ID must not be transferred (in this manner)
}
if allowTransfer {
p.PacketID = pk.PacketID
}
if len(pk.Connect.ProtocolName) > 0 {
@@ -308,7 +314,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
(&pk.Properties).Encode(pk, pb, 0)
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -317,7 +323,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.Connect.WillFlag {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
(&pk.Connect).WillProperties.Encode(pk, pb, 0)
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -378,7 +384,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) //[MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
@@ -388,11 +394,11 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
if pk.ProtocolVersion == 5 {
n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
if err != nil {
return ErrMalformedWillProperties
}
offset += n + 1
offset += n
}
pk.Connect.WillTopic, offset, err = decodeString(buf, offset)
@@ -407,6 +413,10 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
}
if pk.Connect.UsernameFlag { // [MQTT-3.1.3-12]
if offset >= len(buf) { // we are at the end of the packet
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
}
pk.Connect.Username, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedUsername
@@ -438,18 +448,14 @@ func (pk *Packet) ConnectValidate() Code {
return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3]
}
if len(pk.Connect.Password) > 65535 {
if len(pk.Connect.Password) > math.MaxUint16 {
return ErrProtocolViolationPasswordTooLong
}
if len(pk.Connect.Username) > 65535 {
if len(pk.Connect.Username) > math.MaxUint16 {
return ErrProtocolViolationUsernameTooLong
}
if pk.Connect.UsernameFlag && len(pk.Connect.Username) == 0 {
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
}
if !pk.Connect.UsernameFlag && len(pk.Connect.Username) > 0 {
return ErrProtocolViolationUsernameNoFlag // [MQTT-3.1.2-16]
}
@@ -462,7 +468,7 @@ func (pk *Packet) ConnectValidate() Code {
return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18]
}
if len(pk.Connect.ClientIdentifier) > 65535 {
if len(pk.Connect.ClientIdentifier) > math.MaxUint16 {
return ErrClientIdentifierNotValid
}
@@ -491,7 +497,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
nb.Write(pb.Bytes())
}
@@ -534,7 +540,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -603,7 +609,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload))
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
nb.Write(pb.Bytes())
}
@@ -639,7 +645,7 @@ func (pk *Packet) PublishDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.Payload = buf[offset:]
@@ -687,7 +693,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
nb.WriteByte(pk.ReasonCode)
}
@@ -828,7 +834,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes))
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
nb.Write(pb.Bytes())
}
@@ -856,7 +862,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.ReasonCodes = buf[offset:]
@@ -885,7 +891,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -913,7 +919,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
var filter string
@@ -980,7 +986,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -1009,7 +1015,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
pk.ReasonCodes = buf[offset:]
}
@@ -1033,7 +1039,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -1061,7 +1067,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
var filter string
@@ -1096,7 +1102,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
pk.FixedHeader.Remaining = nb.Len()

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -463,6 +464,9 @@ func TestCopy(t *testing.T) {
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
pkcc := tt.Packet.Copy(false)
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
}
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -41,11 +42,11 @@ const (
// validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1},
PropMessageExpiryInterval: {Publish: 1},
PropContentType: {Publish: 1},
PropResponseTopic: {Publish: 1},
PropCorrelationData: {Publish: 1},
PropPayloadFormat: {Publish: 1, WillProperties: 1},
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
PropContentType: {Publish: 1, WillProperties: 1},
PropResponseTopic: {Publish: 1, WillProperties: 1},
PropCorrelationData: {Publish: 1, WillProperties: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1},
@@ -53,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {Connect: 1},
PropWillDelayInterval: {WillProperties: 1},
PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1},
@@ -63,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1},
@@ -193,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool {
}
// Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
if p == nil {
return
}
var buf bytes.Buffer
pkt := pk.FixedHeader.Type
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat)
@@ -216,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData))
}
@@ -276,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RequestResponseInfo)
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo))
}
@@ -288,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString)
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
buf.WriteByte(PropReasonString)
buf.Write(b)
}
@@ -321,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RetainAvailable)
}
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := bytes.NewBuffer([]byte{})
for _, v := range p.User {
pb.WriteByte(PropUser)
@@ -330,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
}
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
buf.Write(pb.Bytes())
}
}
@@ -360,18 +359,19 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
}
// Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
if p == nil {
return 0, nil
}
n, _, err = DecodeLength(b)
var bu int
n, bu, err = DecodeLength(b)
if err != nil {
return n, err
return n + bu, err
}
if n == 0 {
return n, nil
return n + bu, nil
}
bt := b.Bytes()
@@ -379,11 +379,11 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
for offset := 0; offset < n; {
k, offset, err = decodeByte(bt, offset)
if err != nil {
return n, err
return n + bu, err
}
if _, ok := validPacketProperties[k][pk]; !ok {
return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
if _, ok := validPacketProperties[k][pkt]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
}
switch k {
@@ -405,7 +405,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
if err != nil {
return n, err
return n + bu, err
}
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
offset += bu
@@ -451,7 +451,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
var k, v string
k, offset, err = decodeString(bt, offset)
if err != nil {
return n, err
return n + bu, err
}
v, offset, err = decodeString(bt, offset)
p.User = append(p.User, UserProperty{Key: k, Val: v})
@@ -469,9 +469,9 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
}
if err != nil {
return n, err
return n + bu, err
}
}
return n, nil
return n + bu, nil
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -201,14 +202,14 @@ func init() {
func TestEncodeProperties(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes())
}
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
@@ -218,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
@@ -231,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) {
pr := tmp{}
b := bytes.NewBuffer([]byte{})
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
pr.p.Encode(Reserved, Mods{}, b, 0)
require.Equal(t, []byte{}, b.Bytes())
}
@@ -239,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties)
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes())
}
@@ -249,7 +250,7 @@ func TestDecodeProperties(t *testing.T) {
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172, n)
require.Equal(t, 172+2, n)
require.EqualValues(t, propertiesStruct, *props)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
// TPacketCase contains data for cross-checking the encoding and decoding
@@ -70,7 +71,7 @@ const (
TConnectInvalidWillFlagNoPayload
TConnectInvalidWillFlagQosOutOfRange
TConnectInvalidWillSurplusRetain
TConnectNotCleanNoClientID
TConnectZeroByteUsername
TConnectSpecInvalidUTF8D800
TConnectSpecInvalidUTF8DFFF
TConnectSpecInvalidUTF80000
@@ -88,6 +89,7 @@ const (
TConnackServerUnavailable
TConnackBadUsernamePassword
TConnackBadUsernamePasswordNoSession
TConnackMqtt5BadUsernamePasswordNoSession
TConnackNotAuthorised
TConnackMalSessionPresent
TConnackMalReturnCode
@@ -248,26 +250,26 @@ var TPacketData = map[byte]TPacketCases{
Desc: "mqtt v3.1.1",
Primary: true,
RawBytes: []byte{
Connect << 4, 16, // Fixed header
Connect << 4, 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Packet Flags
0, 60, // Keepalive
0, 4, // Client ID - MSB+LSB
'z', 'e', 'n', '3', // Client ID "zen"
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connect,
Remaining: 16,
Remaining: 15,
},
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: false,
Keepalive: 60,
ClientIdentifier: "zen3",
ClientIdentifier: "zen",
},
},
},
@@ -424,9 +426,9 @@ var TPacketData = map[byte]TPacketCases{
Connect << 4, 28, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
194, // Packet Flags
0, 20, // Keepalive
4, // Protocol Version
0 | 1<<6 | 1<<7, // Packet Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 5, // Username MSB+LSB
@@ -442,7 +444,7 @@ var TPacketData = map[byte]TPacketCases{
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: true,
Clean: false,
Keepalive: 20,
ClientIdentifier: "zen",
UsernameFlag: true,
@@ -496,6 +498,43 @@ var TPacketData = map[byte]TPacketCases{
},
},
},
{
Case: TConnectZeroByteUsername,
Desc: "username flag but 0 byte username",
Group: "decode",
RawBytes: []byte{
Connect << 4, 23, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
5, // Protocol Version
130, // Packet Flags
0, 30, // Keepalive
5, // length
17, 0, 0, 0, 120, // Session Expiry Interval (17)
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 0, // Username MSB+LSB
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connect,
Remaining: 23,
},
ProtocolVersion: 5,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: true,
Keepalive: 30,
ClientIdentifier: "zen",
Username: []byte{},
UsernameFlag: true,
},
Properties: Properties{
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
},
},
},
// Fail States
{
@@ -622,6 +661,24 @@ var TPacketData = map[byte]TPacketCases{
'm', 'o', 'c',
},
},
{
Case: TConnectInvalidFlagNoUsername,
Desc: "username flag with no username bytes",
Group: "decode",
FailFirst: ErrProtocolViolationFlagNoUsername,
RawBytes: []byte{
Connect << 4, 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
5, // Protocol Version
130, // Flags
0, 20, // Keepalive
0,
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
},
},
{
Case: TConnectMalPassword,
Desc: "malformed password",
@@ -782,20 +839,6 @@ var TPacketData = map[byte]TPacketCases{
},
},
},
{
Case: TConnectInvalidFlagNoUsername,
Desc: "has username flag but no username",
Group: "validate",
Expect: ErrProtocolViolationFlagNoUsername,
Packet: &Packet{
FixedHeader: FixedHeader{Type: Connect},
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
UsernameFlag: true,
},
},
},
{
Case: TConnectInvalidUsernameNoFlag,
Desc: "has username but no flag",
@@ -1315,10 +1358,28 @@ var TPacketData = map[byte]TPacketCases{
Desc: "bad username or password no session",
RawBytes: []byte{
Connack << 4, 2, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0, // No session present
Err3NotAuthorized.Code, // use v3 remapping
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
},
ReasonCode: Err3NotAuthorized.Code,
},
},
{
Case: TConnackMqtt5BadUsernamePasswordNoSession,
Desc: "mqtt5 bad username or password no session",
RawBytes: []byte{
Connack << 4, 3, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0,
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
@@ -1326,6 +1387,7 @@ var TPacketData = map[byte]TPacketCases{
ReasonCode: ErrBadUsernameOrPassword.Code,
},
},
{
Case: TConnackNotAuthorised,
Desc: "not authorised",
@@ -1803,13 +1865,10 @@ var TPacketData = map[byte]TPacketCases{
Case: TPublishRetainMqtt5,
Desc: "retain mqtt5",
RawBytes: []byte{
Publish<<4 | 1<<0, 35, // Fixed header
Publish<<4 | 1<<0, 19, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
16, // properties length
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
0, // properties length
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
},
Packet: &Packet{
@@ -1817,18 +1876,11 @@ var TPacketData = map[byte]TPacketCases{
FixedHeader: FixedHeader{
Type: Publish,
Retain: true,
Remaining: 35,
Remaining: 19,
},
TopicName: "a/b/c",
Properties: Properties{
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
Payload: []byte("hello mochi"),
TopicName: "a/b/c",
Properties: Properties{},
Payload: []byte("hello mochi"),
},
},
{

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

345
server.go
View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
package mqtt
@@ -25,10 +26,8 @@ import (
)
const (
Version = "2.0.0" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 64 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 32 * 128 // the capacity of each worker queue
Version = "2.2.8" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
)
var (
@@ -46,6 +45,7 @@ var (
SharedSubAvailable: 1, // shared subscriptions are available
ServerKeepAlive: 10, // default keepalive for clients
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
}
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists.
@@ -55,18 +55,20 @@ var (
// Capabilities indicates the capabilities and features provided by the server.
type Capabilities struct {
MaximumMessageExpiryInterval int64
MaximumClientWritesPending int32
MaximumSessionExpiryInterval uint32
MaximumPacketSize uint32
maximumPacketID uint32 // unexported, used for testing only
ReceiveMaximum uint16
TopicAliasMaximum uint16
ServerKeepAlive uint16
SharedSubAvailable byte
MinimumProtocolVersion byte
Compatibilities Compatibilities
MaximumQos byte
RetainAvailable byte
WildcardSubAvailable byte
SubIDAvailable byte
SharedSubAvailable byte
MinimumProtocolVersion byte
}
// Compatibilities provides flags for using compatibility modes.
@@ -79,9 +81,17 @@ type Compatibilities struct {
// Options contains configurable options for the server.
type Options struct {
// Capabilities defines the server features and behaviour.
// Capabilities defines the server features and behaviour. If you only wish to modify
// several of these values, set them explicitly - e.g.
// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
Capabilities *Capabilities
// ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer.
ClientNetWriteBufferSize int
// ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer.
ClientNetReadBufferSize int
// Logger specifies a custom configured implementation of zerolog to override
// the servers default logger configuration. If you wish to change the log level,
// of the default logger, you can do so by setting
@@ -90,16 +100,6 @@ type Options struct {
// server.Log = &l
Logger *zerolog.Logger
// FanPoolSize is the number of individual workers and queues to initialize.
// Bigger is not necessarily better, and you should rely on defaults unless
// you have know what you are doing.
FanPoolSize uint64
// FanPoolQueueSize is the size of the queue per worker. Increase this value
// accordingly if you anticipate having intermittent but massive numbers of
// messages. Cluster support is roadmapped.
FanPoolQueueSize uint64
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
SysTopicResendInterval int64
}
@@ -112,7 +112,6 @@ type Server struct {
Clients *Clients // clients known to the broker
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
Info *system.Info // values about the server commonly known as $SYS topics
fanpool *FanPool // a fixed size worker pool for processing inbound and outbound messages
loop *loop // loop contains tickers for the system event loop
done chan bool // indicate that the server is ending
Log *zerolog.Logger // minimal no-alloc logger
@@ -131,10 +130,10 @@ type loop struct {
// ops contains server values which can be propagated to other structs.
type ops struct {
capabilities *Capabilities // a pointer to the server capabilities, for referencing in clients
info *system.Info // pointers to server system info
hooks *Hooks // pointer to the server hooks
log *zerolog.Logger // a structured logger for the client
options *Options // a pointer to the server options and capabilities, for referencing in clients
info *system.Info // pointers to server system info
hooks *Hooks // pointer to the server hooks
log *zerolog.Logger // a structured logger for the client
}
// New returns a new instance of mochi mqtt broker. Optional parameters
@@ -164,8 +163,7 @@ func New(opts *Options) *Server {
Version: Version,
Started: time.Now().Unix(),
},
fanpool: NewFanPool(opts.FanPoolSize, opts.FanPoolQueueSize),
Log: opts.Logger,
Log: opts.Logger,
hooks: &Hooks{
Log: opts.Logger,
},
@@ -180,16 +178,18 @@ func (o *Options) ensureDefaults() {
o.Capabilities = DefaultServerCapabilities
}
o.Capabilities.maximumPacketID = math.MaxUint16 // spec maximum is 65535
if o.SysTopicResendInterval == 0 {
o.SysTopicResendInterval = defaultSysTopicInterval
}
if o.FanPoolSize == 0 {
o.FanPoolSize = defaultFanPoolSize
if o.ClientNetWriteBufferSize == 0 {
o.ClientNetWriteBufferSize = 1024 * 2
}
if o.FanPoolQueueSize < 1 {
o.FanPoolQueueSize = defaultFanPoolQueueSize
if o.ClientNetReadBufferSize == 0 {
o.ClientNetReadBufferSize = 1024 * 2
}
if o.Logger == nil {
@@ -198,6 +198,33 @@ func (o *Options) ensureDefaults() {
}
}
// NewClient returns a new Client instance, populated with all the required values and
// references to be used with the server. If you are using this client to directly publish
// messages from the embedding application, set the inline flag to true to bypass ACL and
// topic validation checks.
func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client {
cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit
options: s.Options,
info: s.Info,
hooks: s.hooks,
log: s.Log,
})
cl.ID = id
cl.Net.Listener = listener
if inline { // inline clients bypass acl and some validity checks.
cl.Net.Inline = true
// By default we don't want to restrict developer publishes,
// but if you do, reset this after creating inline client.
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
} else {
go cl.WriteLoop() // can only write to real clients
}
return cl
}
// AddHook attaches a new Hook to the server. Ideally, this should be called
// before the server is started with s.Serve().
func (s *Server) AddHook(hook Hook, config any) error {
@@ -280,27 +307,21 @@ func (s *Server) eventLoop() {
}
// EstablishConnection establishes a new client when a listener accepts a new connection.
func (s *Server) EstablishConnection(lid string, c net.Conn) error {
cl := NewClient(c, &ops{ // [MQTT-3.1.2-6] implicit
capabilities: s.Options.Capabilities,
info: s.Info,
hooks: s.hooks,
log: s.Log,
})
return s.attachClient(cl, lid)
func (s *Server) EstablishConnection(listener string, c net.Conn) error {
cl := s.NewClient(c, listener, "", false)
return s.attachClient(cl, listener)
}
// attachClient validates an incoming client connection and if viable, attaches the client
// to the server, performs session housekeeping, and reads incoming packets.
func (s *Server) attachClient(cl *Client, lid string) error {
func (s *Server) attachClient(cl *Client, listener string) error {
defer cl.Stop(nil)
pk, err := s.readConnectionPacket(cl)
if err != nil {
return fmt.Errorf("read connection: %w", err)
}
cl.ParseConnect(lid, pk)
cl.ParseConnect(listener, pk)
code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2]
if code != packets.CodeSuccess {
if err := s.sendConnack(cl, code, false); err != nil {
@@ -346,18 +367,17 @@ func (s *Server) attachClient(cl *Client, lid string) error {
if err != nil {
s.sendLWT(cl)
cl.Stop(err)
}
if err == nil {
} else {
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
}
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", lid).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)
if expire {
s.unsubscribeClient(cl)
if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
cl.ClearInflights(math.MaxInt64, 0)
s.UnsubscribeClient(cl)
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
}
@@ -431,25 +451,40 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
// session is abandoned.
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
existing.Lock()
defer existing.Unlock()
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.unsubscribeClient(existing)
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
return false // [MQTT-3.2.2-3]
atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred
return false // [MQTT-3.2.2-3]
}
atomic.StoreUint32(&existing.State.isTakenOver, 1)
if existing.State.Inflight.Len() > 0 {
cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5]
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 {
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
}
}
cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5]
for _, sub := range existing.State.Subscriptions.GetAll() {
existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if !existed {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sub.Filter, sub)
s.publishRetainedToClient(cl, sub, existed)
}
// Clean the state of the existing client to prevent sequential take-overs
// from increasing memory usage by inflights + subs * client-id.
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
s.Log.Debug().Str("client", cl.ID).
Str("old_remote", existing.Net.Remote).
Str("new_remote", cl.Net.Remote).
Msg("session taken over")
return true // [MQTT-3.2.2-3]
}
@@ -469,6 +504,12 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro
}
if reason.Code >= packets.ErrUnspecifiedError.Code {
if cl.Properties.ProtocolVersion < 5 {
if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes
reason = v3reason
}
}
properties.ReasonString = reason.Reason
ack := packets.Packet{
FixedHeader: packets.FixedHeader{
@@ -568,7 +609,7 @@ func (s *Server) processPacket(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Delete(next.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
}
cl.State.Inflight.TakeSendQuota()
cl.State.Inflight.DecreaseSendQuota()
}
}
@@ -591,6 +632,24 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
})
}
// Publish publishes a publish packet into the broker as if it were sent from the speicfied client.
// This is a convenience function which wraps InjectPacket. As such, this method can publish packets
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
cl := s.NewClient(nil, "local", "inline", true)
return s.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Qos: qos,
Retain: retain,
},
TopicName: topic,
Payload: payload,
PacketID: uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks.
})
}
// InjectPacket injects a packet into the broker as if it were sent from the specified client.
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
@@ -611,7 +670,7 @@ func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
// processPublish processes a Publish packet.
func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if !IsValidFilter(pk.TopicName, true) && !cl.Net.Inline {
if !cl.Net.Inline && !IsValidFilter(pk.TopicName, true) {
return nil
}
@@ -619,20 +678,22 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8]
}
if !s.hooks.OnACLCheck(cl, pk.TopicName, true) && !cl.Net.Inline {
if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) {
return nil
}
pk.Origin = cl.ID
pk.Created = time.Now().Unix()
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
return cl.WritePacket(ack)
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
atomic.AddInt64(&s.Info.Inflight, -1)
if !cl.Net.Inline {
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
return cl.WritePacket(ack)
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
atomic.AddInt64(&s.Info.Inflight, -1)
}
}
}
@@ -655,14 +716,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
}
if pk.FixedHeader.Qos == 0 {
s.fanpool.Enqueue(cl.ID, func() {
s.publishToSubscribers(pk)
})
s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk)
return nil
}
cl.State.Inflight.TakeReceiveQuota()
cl.State.Inflight.DecreaseReceiveQuota()
ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4]
if pk.FixedHeader.Qos == 2 {
ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8]
@@ -670,6 +729,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Set(ack); ok {
atomic.AddInt64(&s.Info.Inflight, 1)
s.hooks.OnQosPublish(cl, ack, ack.Created, 0)
}
err := cl.WritePacket(ack)
@@ -681,15 +741,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Delete(ack.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
}
cl.State.Inflight.ReturnReceiveQuota()
s.hooks.OnQosComplete(cl, pk)
cl.State.Inflight.IncreaseReceiveQuota()
s.hooks.OnQosComplete(cl, ack)
}
s.fanpool.Enqueue(cl.ID, func() {
s.publishToSubscribers(pk)
})
s.hooks.OnPublish(cl, pk)
s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk)
return nil
}
@@ -733,13 +790,13 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
}
}
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (out packets.Packet, err error) {
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
if sub.NoLocal && pk.Origin == cl.ID {
return pk, nil // [MQTT-3.8.3-3]
}
out = pk.Copy(false)
if !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
out := pk.Copy(false)
if cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
out.FixedHeader.Retain = false // [MQTT-3.3.1-12]
}
@@ -769,6 +826,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
if out.FixedHeader.Qos > 0 {
i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1]
if err != nil {
s.hooks.OnPacketIDExhausted(cl, pk)
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted")
return out, packets.ErrQuotaExceeded
}
@@ -779,22 +837,32 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3]
atomic.AddInt64(&s.Info.Inflight, 1)
s.hooks.OnQosPublish(cl, out, out.Created, 0)
cl.State.Inflight.DecreaseSendQuota()
}
if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 {
out.Expiry = -1
cl.State.Inflight.Set(out)
return pk, nil
return out, nil
}
}
if cl.Net.conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
return pk, packets.CodeDisconnect
if cl.Net.Conn == nil || cl.Closed() {
return out, packets.CodeDisconnect
}
cl.State.Inflight.TakeSendQuota()
select {
case cl.State.outbound <- &out:
atomic.AddInt32(&cl.State.outboundQty, 1)
default:
atomic.AddInt64(&s.Info.MessagesDropped, 1)
cl.ops.hooks.OnPublishDropped(cl, pk)
cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight.
cl.State.Inflight.IncreaseSendQuota()
return out, packets.ErrPendingClientWritesExceeded
}
return out, cl.WritePacket(out)
return out, nil
}
func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, existed bool) {
@@ -809,7 +877,7 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e
for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4]
_, err := s.publishToClient(cl, sub, pkv)
if err != nil {
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
}
}
}
@@ -843,7 +911,7 @@ func (s *Server) processPuback(cl *Client, pk packets.Packet) error {
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
cl.State.Inflight.ReturnSendQuota()
cl.State.Inflight.IncreaseSendQuota()
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
}
@@ -866,7 +934,7 @@ func (s *Server) processPubrec(cl *Client, pk packets.Packet) error {
}
ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6]
cl.State.Inflight.TakeReceiveQuota() // -1 RECV QUOTA
cl.State.Inflight.DecreaseReceiveQuota() // -1 RECV QUOTA
cl.State.Inflight.Set(ack) // [MQTT-4.3.3-5]
return cl.WritePacket(ack)
}
@@ -893,8 +961,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
return err
}
cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.ReturnSendQuota() // +1 SENT QUOTA
cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12]
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
@@ -906,8 +974,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
// processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server.
func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error {
// regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas.
cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.ReturnSendQuota() // +1 SENT QUOTA
cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
if ok := cl.State.Inflight.Delete(pk.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
@@ -924,24 +992,24 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
code = packets.ErrPacketIdentifierInUse
}
existed := false
filterExisted := make([]bool, len(pk.Filters))
reasonCodes := make([]byte, len(pk.Filters))
for i, sub := range pk.Filters {
if code != packets.CodeSuccess {
reasonCodes[i] = code.Code // NB 3.9.3 Non-normative 0x91
continue
} else if !IsValidFilter(sub.Filter, false) {
reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
} else if !s.hooks.OnACLCheck(cl, sub.Filter, false) {
reasonCodes[i] = packets.ErrNotAuthorized.Code
if s.Options.Capabilities.Compatibilities.ObscureNotAuthorized {
reasonCodes[i] = packets.ErrUnspecifiedError.Code
}
} else if !IsValidFilter(sub.Filter, false) {
reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
} else {
existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if !existed {
isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if isNew {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]
@@ -950,6 +1018,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
}
filterExisted[i] = !isNew
reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
}
@@ -984,7 +1053,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
continue
}
s.publishRetainedToClient(cl, sub, existed)
s.publishRetainedToClient(cl, sub, filterExisted[i])
}
return nil
@@ -1034,20 +1103,32 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
return cl.WritePacket(ack)
}
// unsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) unsubscribeClient(cl *Client) {
for k := range cl.State.Subscriptions.GetAll() {
// UnsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) UnsubscribeClient(cl *Client) {
i := 0
filterMap := cl.State.Subscriptions.GetAll()
filters := make([]packets.Subscription, len(filterMap))
for k := range filterMap {
cl.State.Subscriptions.Delete(k)
}
if atomic.LoadUint32(&cl.State.isTakenOver) == 1 {
return
}
for k, v := range filterMap {
if s.Topics.Unsubscribe(k, cl.ID) {
atomic.AddInt64(&s.Info.Subscriptions, -1)
}
filters[i] = v
i++
}
s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
}
// processAuth processes an Auth packet.
func (s *Server) processAuth(cl *Client, pk packets.Packet) error {
_, err := s.hooks.OnAuthPacket(cl, pk)
fmt.Println("err", err)
if err != nil {
return err
}
@@ -1086,9 +1167,14 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {
out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1]
}
// We already have a code we are using to disconnect the client, so we are not
// interested if the write packet fails due to a closed connection (as we are closing it).
err := cl.WritePacket(out)
if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
cl.Stop(code)
if code.Code >= packets.ErrUnspecifiedError.Code {
return code
}
}
return err
@@ -1130,6 +1216,7 @@ func (s *Server) publishSysTopics() {
SysPrefix + "/broker/packets/sent": AtomicItoa(&s.Info.PacketsSent),
SysPrefix + "/broker/messages/received": AtomicItoa(&s.Info.MessagesReceived),
SysPrefix + "/broker/messages/sent": AtomicItoa(&s.Info.MessagesSent),
SysPrefix + "/broker/messages/dropped": AtomicItoa(&s.Info.MessagesDropped),
SysPrefix + "/broker/messages/inflight": AtomicItoa(&s.Info.Inflight),
SysPrefix + "/broker/retained": AtomicItoa(&s.Info.Retained),
SysPrefix + "/broker/subscriptions": AtomicItoa(&s.Info.Subscriptions),
@@ -1151,8 +1238,6 @@ func (s *Server) publishSysTopics() {
func (s *Server) Close() error {
close(s.done)
s.Listeners.CloseAll(s.closeListenerClients)
s.fanpool.Close()
s.fanpool.Wait()
s.hooks.OnStopped()
s.hooks.Stop()
@@ -1272,6 +1357,7 @@ func (s *Server) loadServerInfo(v system.Info) {
atomic.StoreInt64(&s.Info.ClientsDisconnected, v.ClientsDisconnected)
atomic.StoreInt64(&s.Info.MessagesReceived, v.MessagesReceived)
atomic.StoreInt64(&s.Info.MessagesSent, v.MessagesSent)
atomic.StoreInt64(&s.Info.MessagesDropped, v.MessagesDropped)
atomic.StoreInt64(&s.Info.PacketsReceived, v.PacketsReceived)
atomic.StoreInt64(&s.Info.PacketsSent, v.PacketsSent)
atomic.StoreInt64(&s.Info.InflightDropped, v.InflightDropped)
@@ -1303,9 +1389,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) {
// loadClients restores clients from the datastore.
func (s *Server) loadClients(v []storage.Client) {
for _, c := range v {
cl := newClientStub()
cl.ID = c.ID
cl.Net.Listener = c.Listener
cl := s.NewClient(nil, c.Listener, c.ID, false)
cl.Properties.Username = c.Username
cl.Properties.Clean = c.Clean
cl.Properties.ProtocolVersion = c.ProtocolVersion
@@ -1331,25 +1415,7 @@ func (s *Server) loadClients(v []storage.Client) {
func (s *Server) loadInflight(v []storage.Message) {
for _, msg := range v {
if client, ok := s.Clients.Get(msg.Origin); ok {
client.State.Inflight.Set(packets.Packet{
FixedHeader: msg.FixedHeader,
PacketID: msg.PacketID,
TopicName: msg.TopicName,
Payload: msg.Payload,
Origin: msg.Origin,
Created: msg.Created,
Properties: packets.Properties{
PayloadFormat: msg.Properties.PayloadFormat,
PayloadFormatFlag: msg.Properties.PayloadFormatFlag,
MessageExpiryInterval: msg.Properties.MessageExpiryInterval,
ContentType: msg.Properties.ContentType,
ResponseTopic: msg.Properties.ResponseTopic,
CorrelationData: msg.Properties.CorrelationData,
SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier,
TopicAlias: msg.Properties.TopicAlias,
User: msg.Properties.User,
},
})
client.State.Inflight.Set(msg.ToPacket())
}
}
}
@@ -1357,24 +1423,7 @@ func (s *Server) loadInflight(v []storage.Message) {
// loadRetained restores retained messages from the datastore.
func (s *Server) loadRetained(v []storage.Message) {
for _, msg := range v {
s.Topics.RetainMessage(packets.Packet{
FixedHeader: msg.FixedHeader,
TopicName: msg.TopicName,
Payload: msg.Payload,
Origin: msg.Origin,
Created: msg.Created,
Properties: packets.Properties{
PayloadFormat: msg.Properties.PayloadFormat,
PayloadFormatFlag: msg.Properties.PayloadFormatFlag,
MessageExpiryInterval: msg.Properties.MessageExpiryInterval,
ContentType: msg.Properties.ContentType,
ResponseTopic: msg.Properties.ResponseTopic,
CorrelationData: msg.Properties.CorrelationData,
SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier,
TopicAlias: msg.Properties.TopicAlias,
User: msg.Properties.User,
},
})
s.Topics.RetainMessage(msg.ToPacket())
}
}
@@ -1412,8 +1461,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) {
// clearExpiredInflights deletes any inflight messages which have expired.
func (s *Server) clearExpiredInflights(now int64) {
for _, client := range s.Clients.GetAll() {
if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 {
s.hooks.OnExpireInflights(client, now)
if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
for _, id := range deleted {
s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,11 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package system
import "sync/atomic"
// Info contains atomic counters and values for various server statistics
// commonly found in $SYS topics (and others).
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
@@ -19,6 +22,7 @@ type Info struct {
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
Retained int64 `json:"retained"` // total number of retained messages active on the broker
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
@@ -28,3 +32,30 @@ type Info struct {
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
}
// Clone makes a copy of Info using atomic operation
func (i *Info) Clone() *Info {
return &Info{
Version: i.Version,
Started: atomic.LoadInt64(&i.Started),
Time: atomic.LoadInt64(&i.Time),
Uptime: atomic.LoadInt64(&i.Uptime),
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
BytesSent: atomic.LoadInt64(&i.BytesSent),
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
Retained: atomic.LoadInt64(&i.Retained),
Inflight: atomic.LoadInt64(&i.Inflight),
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
Threads: atomic.LoadInt64(&i.Threads),
}
}

37
system/system_test.go Normal file
View File

@@ -0,0 +1,37 @@
package system
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestClone(t *testing.T) {
o := &Info{
Version: "version",
Started: 1,
Time: 2,
Uptime: 3,
BytesReceived: 4,
BytesSent: 5,
ClientsConnected: 6,
ClientsMaximum: 7,
ClientsTotal: 8,
ClientsDisconnected: 9,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
Retained: 12,
Inflight: 13,
InflightDropped: 14,
Subscriptions: 15,
PacketsReceived: 16,
PacketsSent: 17,
MemoryAlloc: 18,
Threads: 19,
}
n := o.Clone()
require.Equal(t, o, n)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -300,6 +301,9 @@ func NewTopicsIndex() *TopicsIndex {
// Subscribe adds a new subscription for a client to a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
prefix, _ := isolateParticle(subscription.Filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
@@ -319,6 +323,9 @@ func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription
// Unsubscribe removes a subscription filter for a client, returning true if the
// subscription existed.
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
x.root.Lock()
defer x.root.Unlock()
var d int
if strings.HasPrefix(filter, SharePrefix) {
d = 2
@@ -345,7 +352,12 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
x.root.Lock()
defer x.root.Unlock()
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()
if len(pk.Payload) > 0 {
n.retainPath = pk.TopicName
x.Retained.Add(pk.TopicName, pk)
@@ -360,6 +372,7 @@ func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
n.retainPath = ""
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
x.trim(n)
return out
}
@@ -618,6 +631,7 @@ type particle struct {
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle
}
// newParticle returns a pointer to a new instance of particle.

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (

View File

@@ -32,7 +32,7 @@ type histogram struct {
valueCount int64 // number of values recorded for single value
}
// AddMeasurement records a value measurement observation to the histogram.
// addMeasurement records a value measurement observation to the histogram.
func (h *histogram) addMeasurement(value int64) {
// TODO: assert invariant
h.sum += value

View File

@@ -395,7 +395,7 @@ func New(family, title string) Trace {
}
func (tr *trace) Finish() {
elapsed := time.Now().Sub(tr.Start)
elapsed := time.Since(tr.Start)
tr.mu.Lock()
tr.Elapsed = elapsed
tr.mu.Unlock()

View File

@@ -1,30 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package unsafeheader contains header declarations for the Go runtime's
// slice and string implementations.
//
// This package allows x/sys to use types equivalent to
// reflect.SliceHeader and reflect.StringHeader without introducing
// a dependency on the (relatively heavy) "reflect" package.
package unsafeheader
import (
"unsafe"
)
// Slice is the runtime representation of a slice.
// It cannot be used safely or portably and its representation may change in a later release.
type Slice struct {
Data unsafe.Pointer
Len int
Cap int
}
// String is the runtime representation of a string.
// It cannot be used safely or portably and its representation may change in a later release.
type String struct {
Data unsafe.Pointer
Len int
}

31
vendor/golang.org/x/sys/unix/asm_bsd_ppc64.s generated vendored Normal file
View File

@@ -0,0 +1,31 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || freebsd || netbsd || openbsd) && gc
// +build darwin freebsd netbsd openbsd
// +build gc
#include "textflag.h"
//
// System call support for ppc64, BSD
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package unix

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gccgo && !aix
// +build gccgo,!aix
//go:build gccgo && !aix && !hurd
// +build gccgo,!aix,!hurd
package unix

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build gccgo
// +build !aix
//go:build gccgo && !aix && !hurd
// +build gccgo,!aix,!hurd
#include <errno.h>
#include <stdint.h>

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package unix

View File

@@ -4,9 +4,7 @@
package unix
import (
"unsafe"
)
import "unsafe"
// IoctlRetInt performs an ioctl operation specified by req on a device
// associated with opened file descriptor fd, and returns a non-negative
@@ -217,3 +215,19 @@ func IoctlKCMAttach(fd int, info KCMAttach) error {
func IoctlKCMUnattach(fd int, info KCMUnattach) error {
return ioctlPtr(fd, SIOCKCMUNATTACH, unsafe.Pointer(&info))
}
// IoctlLoopGetStatus64 gets the status of the loop device associated with the
// file descriptor fd using the LOOP_GET_STATUS64 operation.
func IoctlLoopGetStatus64(fd int) (*LoopInfo64, error) {
var value LoopInfo64
if err := ioctlPtr(fd, LOOP_GET_STATUS64, unsafe.Pointer(&value)); err != nil {
return nil, err
}
return &value, nil
}
// IoctlLoopSetStatus64 sets the status of the loop device associated with the
// file descriptor fd using the LOOP_SET_STATUS64 operation.
func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error {
return ioctlPtr(fd, LOOP_SET_STATUS64, unsafe.Pointer(value))
}

View File

@@ -73,12 +73,12 @@ aix_ppc64)
darwin_amd64)
mkerrors="$mkerrors -m64"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
mkasm="go run mkasm_darwin.go"
mkasm="go run mkasm.go"
;;
darwin_arm64)
mkerrors="$mkerrors -m64"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
mkasm="go run mkasm_darwin.go"
mkasm="go run mkasm.go"
;;
dragonfly_amd64)
mkerrors="$mkerrors -m64"
@@ -142,42 +142,60 @@ netbsd_arm64)
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_386)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m32"
mksyscall="go run mksyscall.go -l32 -openbsd"
mksyscall="go run mksyscall.go -l32 -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_amd64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_arm)
mkasm="go run mkasm.go"
mkerrors="$mkerrors"
mksyscall="go run mksyscall.go -l32 -openbsd -arm"
mksyscall="go run mksyscall.go -l32 -openbsd -arm -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_arm64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_mips64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_ppc64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_riscv64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
@@ -214,11 +232,6 @@ esac
if [ "$GOOSARCH" == "aix_ppc64" ]; then
# aix/ppc64 script generates files instead of writing to stdin.
echo "$mksyscall -tags $GOOS,$GOARCH $syscall_goos $GOOSARCH_in && gofmt -w zsyscall_$GOOSARCH.go && gofmt -w zsyscall_"$GOOSARCH"_gccgo.go && gofmt -w zsyscall_"$GOOSARCH"_gc.go " ;
elif [ "$GOOS" == "darwin" ]; then
# 1.12 and later, syscalls via libSystem
echo "$mksyscall -tags $GOOS,$GOARCH,go1.12 $syscall_goos $GOOSARCH_in |gofmt >zsyscall_$GOOSARCH.go";
# 1.13 and later, syscalls via libSystem (including syscallPtr)
echo "$mksyscall -tags $GOOS,$GOARCH,go1.13 syscall_darwin.1_13.go |gofmt >zsyscall_$GOOSARCH.1_13.go";
elif [ "$GOOS" == "illumos" ]; then
# illumos code generation requires a --illumos switch
echo "$mksyscall -illumos -tags illumos,$GOARCH syscall_illumos.go |gofmt > zsyscall_illumos_$GOARCH.go";
@@ -232,5 +245,5 @@ esac
if [ -n "$mksysctl" ]; then echo "$mksysctl |gofmt >$zsysctl"; fi
if [ -n "$mksysnum" ]; then echo "$mksysnum |gofmt >zsysnum_$GOOSARCH.go"; fi
if [ -n "$mktypes" ]; then echo "$mktypes types_$GOOS.go | go run mkpost.go > ztypes_$GOOSARCH.go"; fi
if [ -n "$mkasm" ]; then echo "$mkasm $GOARCH"; fi
if [ -n "$mkasm" ]; then echo "$mkasm $GOOS $GOARCH"; fi
) | $run

View File

@@ -642,7 +642,7 @@ errors=$(
signals=$(
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' |
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
sort
)
@@ -652,7 +652,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
sort >_error.grep
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' |
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
sort >_signal.grep
echo '// mkerrors.sh' "$@"

View File

@@ -52,6 +52,20 @@ func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) {
return msgs, nil
}
// ParseOneSocketControlMessage parses a single socket control message from b, returning the message header,
// message data (a slice of b), and the remainder of b after that single message.
// When there are no remaining messages, len(remainder) == 0.
func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) {
h, dbuf, err := socketControlMessageHeaderAndData(b)
if err != nil {
return Cmsghdr{}, nil, nil, err
}
if i := cmsgAlignOf(int(h.Len)); i < len(b) {
remainder = b[i:]
}
return *h, dbuf, remainder, nil
}
func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) {
h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) {

27
vendor/golang.org/x/sys/unix/str.go generated vendored
View File

@@ -1,27 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package unix
func itoa(val int) string { // do it here rather than with fmt to avoid dependency
if val < 0 {
return "-" + uitoa(uint(-val))
}
return uitoa(uint(val))
}
func uitoa(val uint) string {
var buf [32]byte // big enough for int64
i := len(buf) - 1
for val >= 10 {
buf[i] = byte(val%10 + '0')
i--
val /= 10
}
buf[i] = byte(val + '0')
return string(buf[i:])
}

View File

@@ -29,8 +29,6 @@ import (
"bytes"
"strings"
"unsafe"
"golang.org/x/sys/internal/unsafeheader"
)
// ByteSliceFromString returns a NUL-terminated slice of bytes
@@ -82,13 +80,7 @@ func BytePtrToString(p *byte) string {
ptr = unsafe.Pointer(uintptr(ptr) + 1)
}
var s []byte
h := (*unsafeheader.Slice)(unsafe.Pointer(&s))
h.Data = unsafe.Pointer(p)
h.Len = n
h.Cap = n
return string(s)
return string(unsafe.Slice(p, n))
}
// Single-word zero for use when we need a valid pointer to 0 bytes.

View File

@@ -253,7 +253,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
var empty bool
if len(oob) > 0 {
// send at least one normal byte
empty := emptyIovecs(iov)
empty = emptyIovecs(iov)
if empty {
var iova [1]Iovec
iova[0].Base = &dummy

View File

@@ -363,7 +363,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
var empty bool
if len(oob) > 0 {
// send at least one normal byte
empty := emptyIovecs(iov)
empty = emptyIovecs(iov)
if empty {
var iova [1]Iovec
iova[0].Base = &dummy

View File

@@ -1,32 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.12 && !go1.13
// +build darwin,go1.12,!go1.13
package unix
import (
"unsafe"
)
const _SYS_GETDIRENTRIES64 = 344
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// To implement this using libSystem we'd need syscall_syscallPtr for
// fdopendir. However, syscallPtr was only added in Go 1.13, so we fall
// back to raw syscalls for this func on Go 1.12.
var p unsafe.Pointer
if len(buf) > 0 {
p = unsafe.Pointer(&buf[0])
} else {
p = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(_SYS_GETDIRENTRIES64, uintptr(fd), uintptr(p), uintptr(len(buf)), uintptr(unsafe.Pointer(basep)), 0, 0)
n = int(r0)
if e1 != 0 {
return n, errnoErr(e1)
}
return n, nil
}

View File

@@ -1,108 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.13
// +build darwin,go1.13
package unix
import (
"unsafe"
"golang.org/x/sys/internal/unsafeheader"
)
//sys closedir(dir uintptr) (err error)
//sys readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno)
func fdopendir(fd int) (dir uintptr, err error) {
r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0)
dir = uintptr(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_fdopendir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib"
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// Simulate Getdirentries using fdopendir/readdir_r/closedir.
// We store the number of entries to skip in the seek
// offset of fd. See issue #31368.
// It's not the full required semantics, but should handle the case
// of calling Getdirentries or ReadDirent repeatedly.
// It won't handle assigning the results of lseek to *basep, or handle
// the directory being edited underfoot.
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
if err != nil {
return 0, err
}
// We need to duplicate the incoming file descriptor
// because the caller expects to retain control of it, but
// fdopendir expects to take control of its argument.
// Just Dup'ing the file descriptor is not enough, as the
// result shares underlying state. Use Openat to make a really
// new file descriptor referring to the same directory.
fd2, err := Openat(fd, ".", O_RDONLY, 0)
if err != nil {
return 0, err
}
d, err := fdopendir(fd2)
if err != nil {
Close(fd2)
return 0, err
}
defer closedir(d)
var cnt int64
for {
var entry Dirent
var entryp *Dirent
e := readdir_r(d, &entry, &entryp)
if e != 0 {
return n, errnoErr(e)
}
if entryp == nil {
break
}
if skip > 0 {
skip--
cnt++
continue
}
reclen := int(entry.Reclen)
if reclen > len(buf) {
// Not enough room. Return for now.
// The counter will let us know where we should start up again.
// Note: this strategy for suspending in the middle and
// restarting is O(n^2) in the length of the directory. Oh well.
break
}
// Copy entry into return buffer.
var s []byte
hdr := (*unsafeheader.Slice)(unsafe.Pointer(&s))
hdr.Data = unsafe.Pointer(&entry)
hdr.Cap = reclen
hdr.Len = reclen
copy(buf, s)
buf = buf[reclen:]
n += reclen
cnt++
}
// Set the seek offset of the input fd to record
// how many files we've already returned.
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
if err != nil {
return n, err
}
return n, nil
}

View File

@@ -19,6 +19,96 @@ import (
"unsafe"
)
//sys closedir(dir uintptr) (err error)
//sys readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno)
func fdopendir(fd int) (dir uintptr, err error) {
r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0)
dir = uintptr(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_fdopendir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib"
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// Simulate Getdirentries using fdopendir/readdir_r/closedir.
// We store the number of entries to skip in the seek
// offset of fd. See issue #31368.
// It's not the full required semantics, but should handle the case
// of calling Getdirentries or ReadDirent repeatedly.
// It won't handle assigning the results of lseek to *basep, or handle
// the directory being edited underfoot.
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
if err != nil {
return 0, err
}
// We need to duplicate the incoming file descriptor
// because the caller expects to retain control of it, but
// fdopendir expects to take control of its argument.
// Just Dup'ing the file descriptor is not enough, as the
// result shares underlying state. Use Openat to make a really
// new file descriptor referring to the same directory.
fd2, err := Openat(fd, ".", O_RDONLY, 0)
if err != nil {
return 0, err
}
d, err := fdopendir(fd2)
if err != nil {
Close(fd2)
return 0, err
}
defer closedir(d)
var cnt int64
for {
var entry Dirent
var entryp *Dirent
e := readdir_r(d, &entry, &entryp)
if e != 0 {
return n, errnoErr(e)
}
if entryp == nil {
break
}
if skip > 0 {
skip--
cnt++
continue
}
reclen := int(entry.Reclen)
if reclen > len(buf) {
// Not enough room. Return for now.
// The counter will let us know where we should start up again.
// Note: this strategy for suspending in the middle and
// restarting is O(n^2) in the length of the directory. Oh well.
break
}
// Copy entry into return buffer.
s := unsafe.Slice((*byte)(unsafe.Pointer(&entry)), reclen)
copy(buf, s)
buf = buf[reclen:]
n += reclen
cnt++
}
// Set the seek offset of the input fd to record
// how many files we've already returned.
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
if err != nil {
return n, err
}
return n, nil
}
// SockaddrDatalink implements the Sockaddr interface for AF_LINK type sockets.
type SockaddrDatalink struct {
Len uint8
@@ -140,6 +230,7 @@ func direntNamlen(buf []byte) (uint64, bool) {
func PtraceAttach(pid int) (err error) { return ptrace(PT_ATTACH, pid, 0, 0) }
func PtraceDetach(pid int) (err error) { return ptrace(PT_DETACH, pid, 0, 0) }
func PtraceDenyAttach() (err error) { return ptrace(PT_DENY_ATTACH, 0, 0, 0) }
//sysnb pipe(p *[2]int32) (err error)

View File

@@ -255,6 +255,7 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
//sys Chmod(path string, mode uint32) (err error)
//sys Chown(path string, uid int, gid int) (err error)
//sys Chroot(path string) (err error)
//sys ClockGettime(clockid int32, time *Timespec) (err error)
//sys Close(fd int) (err error)
//sys Dup(fd int) (nfd int, err error)
//sys Dup2(from int, to int) (err error)

View File

@@ -319,6 +319,7 @@ func PtraceSingleStep(pid int) (err error) {
//sys Chmod(path string, mode uint32) (err error)
//sys Chown(path string, uid int, gid int) (err error)
//sys Chroot(path string) (err error)
//sys ClockGettime(clockid int32, time *Timespec) (err error)
//sys Close(fd int) (err error)
//sys Dup(fd int) (nfd int, err error)
//sys Dup2(from int, to int) (err error)

View File

@@ -60,8 +60,13 @@ func PtraceGetFsBase(pid int, fsbase *int64) (err error) {
return ptrace(PT_GETFSBASE, pid, uintptr(unsafe.Pointer(fsbase)), 0)
}
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint32(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint32(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

View File

@@ -60,8 +60,13 @@ func PtraceGetFsBase(pid int, fsbase *int64) (err error) {
return ptrace(PT_GETFSBASE, pid, uintptr(unsafe.Pointer(fsbase)), 0)
}
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint64(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

View File

@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint32(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint32(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

View File

@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint64(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

View File

@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint64(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

22
vendor/golang.org/x/sys/unix/syscall_hurd.go generated vendored Normal file
View File

@@ -0,0 +1,22 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build hurd
// +build hurd
package unix
/*
#include <stdint.h>
int ioctl(int, unsigned long int, uintptr_t);
*/
import "C"
func ioctl(fd int, req uint, arg uintptr) (err error) {
r0, er := C.ioctl(C.int(fd), C.ulong(req), C.uintptr_t(arg))
if r0 == -1 && er != nil {
err = er
}
return
}

29
vendor/golang.org/x/sys/unix/syscall_hurd_386.go generated vendored Normal file
View File

@@ -0,0 +1,29 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build 386 && hurd
// +build 386,hurd
package unix
const (
TIOCGETA = 0x62251713
)
type Winsize struct {
Row uint16
Col uint16
Xpixel uint16
Ypixel uint16
}
type Termios struct {
Iflag uint32
Oflag uint32
Cflag uint32
Lflag uint32
Cc [20]uint8
Ispeed int32
Ospeed int32
}

Some files were not shown because too many files have changed in this diff Show More