Compare commits

...

234 Commits

Author SHA1 Message Date
mochi-co
aac186dcc1 Add newline for godoc formatting 2022-12-11 22:25:21 +00:00
JB
42931f332f Update badges to use v2 references 2022-12-11 21:44:44 +00:00
mochi-co
8a04648c09 Cleanup godoc formatting 2022-12-11 21:38:01 +00:00
JB
854c033fb6 Update README.md 2022-12-11 12:21:25 +00:00
mochi-co
74ed8cd046 Update go mod and imports to v2 2022-12-11 11:50:44 +00:00
JB
be164fa715 Update go.mod 2022-12-11 11:43:55 +00:00
JB
4287955161 Update go.mod 2022-12-11 11:42:43 +00:00
JB
bbf08ff496 Update README.md 2022-12-10 22:51:01 +00:00
JB
c38201ff8b Update README.md 2022-12-10 22:32:53 +00:00
JB
f8b4ff5c0d Update README.md 2022-12-10 22:29:11 +00:00
JB
661e23e051 Update README.md 2022-12-10 22:11:39 +00:00
JB
9c99db426c Update build.yml 2022-12-10 21:58:03 +00:00
mochi-co
40b7273a53 update github workflow go version to 1.19 2022-12-10 21:52:06 +00:00
mochi-co
898c90d4ca Rewrite everything from scratch for mqtt v5 2022-12-10 21:49:32 +00:00
JB
bc3d8b0eaa Update README.md 2022-11-14 21:57:14 +00:00
JB
35bd928714 Update README.md 2022-11-14 21:56:29 +00:00
JB
20c2655d0e Update README.md 2022-10-01 11:29:20 +01:00
JB
fec25f29c6 Update README.md 2022-10-01 00:23:07 +01:00
JB
1d7a322229 Update README.md 2022-10-01 00:22:08 +01:00
JB
d8b38a4ae2 Update README.md 2022-09-16 13:37:35 +01:00
JB
a83c0c4fd0 Contributions paused pending v2.0.0 2022-09-16 13:37:07 +01:00
JB
66a1d19e89 Update README.md 2022-09-16 13:32:53 +01:00
mochi-co
0dbebbc066 Revert "Merge pull request #97 from alexsporn/fix/writer-full"
This reverts commit f22b8276e8, reversing
changes made to b2fc287a98.
2022-09-11 22:45:57 +01:00
JB
f22b8276e8 Merge pull request #97 from alexsporn/fix/writer-full
Instead of waiting for the writing buffer to have enough space, skip writing and return an error
2022-09-10 18:20:09 +01:00
Alexander Sporn
d60c438960 Merge branch 'mochi-co:master' into fix/writer-full 2022-09-09 16:07:21 +02:00
JB
b2fc287a98 Merge pull request #99 from mochi-co/fix-inflight-race
Fix concurrent map access for clients and inflights causes data race
2022-09-02 21:05:47 +01:00
mochi-co
3e3ba20b08 Return copies of client and inflight maps to avoid missed locks 2022-09-02 20:54:20 +01:00
mochi-co
9ee462c777 Increase inlinepub messages buffer 2022-09-02 20:53:34 +01:00
Alexander Sporn
3c89114bba Instead of waiting for the writing buffer to have enough space, skip writing and return an error 2022-09-02 12:23:48 +02:00
mochi-co
ecbd07fa3a Check against the correct clean session var for abandoning old inflights 2022-08-18 00:19:11 +01:00
zynzel
ad8bf2a931 Keep in sync server.System.Inflight (#92)
* Keep in sync server.System.Inflight

* Fix args order in tests
2022-08-17 23:58:43 +01:00
JB
b8fb068bb9 Update README.md 2022-08-16 22:21:23 +01:00
JB
c1348a37b8 Update README.md 2022-08-16 22:20:50 +01:00
mochi-co
84fc2f848b Abandon inflights at the end of clean-session connections 2022-08-16 21:41:39 +01:00
JB
8703d6d020 Merge pull request #90 from mochi-co/resend-inflights
Adds Inflight TTL and Period Resend
2022-08-16 21:31:42 +01:00
mochi-co
666440fe56 Adds Inflight TTL and Period Resend 2022-08-16 21:19:42 +01:00
JB
1ae050939a Merge pull request #84 from mochi-co/goreport-fixes
Goreport fixes
2022-06-22 15:52:36 +01:00
mochi
f4683d27d0 remove ineffective assignments 2022-06-22 15:45:13 +01:00
mochi
dff2b1db30 apply gofmt -s 2022-06-22 15:40:52 +01:00
JB
9de6b4e427 Merge pull request #83 from mochi-co/tls-client-auth
Expose tls.Config to Listeners
2022-06-22 15:32:23 +01:00
JB
78c1914270 Merge pull request #82 from mochi-co/expose-event-client-username
Add CleanSession and Username to events.Client struct
2022-06-22 15:31:31 +01:00
mochi
f71bf5c3d6 use TLSConfig instead of deprecated TLS field 2022-06-22 15:26:51 +01:00
mochi
53c4a6b09f Add TLSConfig field to allow direct tls.Config setting 2022-06-22 15:26:26 +01:00
mochi
a02c6bd8df update TLS example to use TLSConfig field 2022-06-22 15:25:52 +01:00
mochi
d8f6d63cc8 Add CleanSession and Username to events.Client struct 2022-06-22 12:33:09 +01:00
mochi
bef13eec20 Add OnSubscribe, OnUnsubscribe events examples 2022-05-04 12:58:23 +01:00
mochi
27f3c484ad Extend onsusbcribe, onunsubscribe events 2022-05-04 12:53:04 +01:00
JB
9b5cdb0bcc Merge pull request #74 from muXxer/feat/topic-subscription-events 2022-05-04 12:33:12 +01:00
muXxer
2b60a11d4a Add topic un-/subscribe events 2022-04-28 00:48:20 +02:00
JB
b53774f818 Merge pull request #72 from BoskyWSMFN/master
fix-panic
2022-04-19 08:58:46 +01:00
BoskyWSMFN
7dee729afb fix-panic
fixed runtime panic in server/internal/circ/pool.go occurring on 32-bits architectures caused by misalignment of BytesPool struct members.

https://github.com/golang/go/issues/36606#issue-551005857
2022-04-19 00:26:44 +03:00
mochi
aed535b7bf fix comments 2022-04-13 10:46:13 +01:00
mochi
4ff888ab3b Add auth controller example 2022-04-13 10:37:24 +01:00
JB
31252c081b Add Docker info 2022-04-10 20:46:16 +01:00
JB
3a15cc3add Merge pull request #69 from mochi-co/v1.2.0
V1.2.0
2022-04-10 20:23:43 +01:00
mochi
765f6e7c2e use NewServer instead of New 2022-04-10 20:19:06 +01:00
mochi
e3bdfc1f8e update readme for new events 2022-04-10 20:18:18 +01:00
JB
60cd972b7f Merge pull request #68 from mochi-co/fix-store-retained
Fix Store Retained Messages
2022-04-10 19:55:34 +01:00
mochi
ee081d0abe Merge branch 'v1.2.0' of https://github.com/mochi-co/mqtt into fix-store-retained 2022-04-10 19:54:37 +01:00
JB
d97b4bb81d Merge pull request #67 from mochi-co/release-client-buffers
Release client buffers
2022-04-10 19:53:46 +01:00
mochi
94aeacf0cb only check final outcome due to races 2022-04-10 19:46:12 +01:00
mochi
5cb8a081a1 accept any error for invalid protocol due to races 2022-04-10 19:43:59 +01:00
mochi
fc00112e47 Check for protocol violation errors 2022-04-10 19:42:37 +01:00
mochi
bbc22fae5b Add comments 2022-04-10 19:30:26 +01:00
mochi
82cb75913d Remove unused code 2022-04-10 19:28:22 +01:00
mochi
45c4a64b87 Abandon client state if the existing client specified a cleansession 2022-04-10 19:07:46 +01:00
mochi
bae2579497 Expose CleanSession value for checking 2022-04-10 19:07:15 +01:00
mochi
1bc01271cb Store retained message based on corrected r value 2022-04-10 18:47:45 +01:00
mochi
352a71f50c Expect correct r values for RetainMessage 2022-04-10 18:46:46 +01:00
mochi
6f9f62e38f Correctly return R value of retainMessage 2022-04-10 18:46:23 +01:00
mochi
0f67d9e8ff Update EstablishConnection tests to ensure buffers and pool are correctly released after use 2022-04-10 17:37:46 +01:00
mochi
f2dd5b63ae Export R/W buffer values so they can be assessed in tests without causing races 2022-04-10 17:37:25 +01:00
mochi
54e2d044a2 Track number of pool blocks in use 2022-04-10 17:36:46 +01:00
mochi
6298a87298 Use package errors instead of strings 2022-04-10 14:58:59 +01:00
mochi
b6fd25bba4 test clarbuffers 2022-04-10 01:46:18 +01:00
mochi
eef3592576 clear buffers after deferred stop 2022-04-10 01:46:04 +01:00
mochi
5d343c12e1 refactor clients for buffer releasing 2022-04-10 01:34:01 +01:00
mochi
70f52c8a3b Refactor establishconnection to prevent same-id disconnects 2022-04-10 01:33:35 +01:00
mochi
429b72265a refactor connSetup for clarity 2022-04-10 01:32:24 +01:00
mochi
f60d2dcfca Clarify error messages 2022-04-10 01:32:04 +01:00
mochi
6674cd64eb clarify error checking 2022-04-10 01:31:36 +01:00
mochi
f218cde69c Use defer to release buffers and decrease stats on any client closure 2022-04-09 20:53:28 +01:00
JB
9ea687eb94 Merge pull request #61 from mochi-co/server-options
Configurable Server Options
2022-04-09 20:29:23 +01:00
JB
949e4e2e91 Merge pull request #63 from mochi-co/add-drop-packet-error
Add ErrRejectPacket to OnProcessMessage
2022-04-09 20:29:03 +01:00
JB
515e0269de Merge pull request #62 from mochi-co/fix-inflight-key
Fix Inflight Persistence Key
2022-04-06 11:23:51 +01:00
mochi
ae6073c79c track logged error 2022-03-31 18:23:12 +01:00
mochi
b072a08f0b Update test to check for packet rejection 2022-03-31 18:04:41 +01:00
mochi
01d8a450d2 Ensure OnError is set before using it 2022-03-31 17:56:08 +01:00
mochi
da2fd41f79 Update OnProcessMessage documentation 2022-03-31 17:53:34 +01:00
mochi
56e8039093 Optionally drop a packet if the ErrRejectPacket error is returned from OnProcessMessage 2022-03-31 17:53:13 +01:00
mochi
7b9bc844c1 Add ErrRejectPacket error to abandon packet processing from OnMessageProcess 2022-03-31 17:52:40 +01:00
JB
8acb182820 Merge pull request #53 from stffabi/feature/onprocessmessage-event
Events: Add OnProcessMessage event
2022-03-31 17:41:22 +01:00
mochi
5726880095 fix inflight key reference 2022-03-31 17:36:03 +01:00
mochi
ee459e1b3d Fix code block formatting 2022-03-31 17:32:17 +01:00
mochi
6aec3a8bbf Update readme with server options 2022-03-31 17:21:42 +01:00
mochi
74699f0a87 Add example implementation 2022-03-31 17:21:35 +01:00
mochi
70def39ff9 add tests for new NewServer function 2022-03-31 17:00:00 +01:00
mochi
8e7098a32d remove deprecated log message 2022-03-31 16:59:49 +01:00
mochi
e4f02919fd Update code to use new NewServer function instead of deprecated New 2022-03-31 16:49:27 +01:00
mochi
99c96c844e Update example code to use new NewServer function instead of deprecated New 2022-03-31 16:49:02 +01:00
mochi
18629aea6d Use internal default values instead of relying on passed value 2022-03-31 16:48:42 +01:00
mochi
a0060429d1 Add Server Options
Adds a new struct of server options which can be used to override default properties. A new options-accepting NewServer function has been created to supersede the New method, which is now deprecated.
2022-03-31 16:48:29 +01:00
mochi
d946a9ae16 Update go mod to ensure bolt is using 1.3.5
Bolt 1.3.6 fails to build correctly and has been removed, so rollback to bolt 1.3.5. Also upgrade to Go 1.18
2022-03-31 16:39:28 +01:00
JB
51a2eb5f48 Merge pull request #57 from hybridgroup/v1.2.0-docker
docker: add initial simple Dockerfile
2022-03-30 09:37:53 +01:00
JB
0d4b0a89d8 Merge pull request #58 from soyoo/patch-1
typo
2022-03-30 09:30:11 +01:00
soyoo
0e7ccfe3fb typo 2022-03-25 16:55:57 +08:00
Ron Evans
5d7230630d docker: add initial simple Dockerfile 2022-03-23 10:37:20 +01:00
JB
6a3cbd6093 Merge pull request #51 from jmacd/jmacd/noracefix
Two no-functional-change cleanups combined
2022-03-22 16:19:24 +00:00
stffabi
7b4e79707b Events: Add OnProcessMessage event
This event gets called right after ACL checking but before any other
Fields of the packet get evaluated.
2022-03-21 10:11:29 +01:00
Joshua MacDonald
c6643592f6 Combines two fixes 2022-03-17 14:19:23 -07:00
mochi
5de12d0460 Merge branch 'master' of https://github.com/mochi-co/mqtt into v1.1.2 2022-03-17 12:53:59 +00:00
JB
0a7205e110 Update README.md 2022-03-17 12:50:01 +00:00
JB
8133dd8299 Update README.md 2022-03-17 12:48:37 +00:00
JB
fdbfff57dc Merge pull request #46 from stffabi/bugfix/acls-retain
Subscribe: Only send retained messages if ACLs has allowed subscription to the topic
2022-03-17 09:41:48 +00:00
stffabi
f5fc5e8c44 Subscribe: Only send retained messages if ACLs has allowed subscription to the topic 2022-03-17 09:25:29 +01:00
mochi
9f44712b80 Fix incorrect test
The previous publish inline test incorrectly approved retain packets without retain=true fixedheader values.
2022-03-16 18:16:48 +00:00
stffabi
1f86168d9d Publish: Set the retain flag in the fixedheader (#42)
* Publish: Set the retain flag in the fixedheader
2022-03-16 18:12:07 +00:00
mochi
ab25083ed2 Merge branch 'master' of https://github.com/mochi-co/mqtt into v1.1.2
# Conflicts:
#	server/internal/clients/clients.go
2022-03-16 18:09:00 +00:00
JB
9b0aa4d559 Update README.md 2022-03-15 20:04:04 +00:00
JB
03814944a9 Update README.md 2022-03-15 19:58:58 +00:00
JB
3286d5a484 Replace Travis with Github Actions (#41)
* Remove Travis CI

* Add Github Actions Workflow

* Update badges for build status, coverage, report card, doc reference

* use actions for all pull requests and pushes

* test all files for coverage

* Apply gofmt -s to simplify code

* Fix typos

* Cleanup comments

* Cleanup comments

Co-authored-by: mochi <mochimou@icloud.com>
2022-03-15 19:56:42 +00:00
mochi
7e970d3c7a Fix typo 2022-03-15 19:13:24 +00:00
mochi
d6a92cc5bd Add Keyed fields to events.Client for readability and go vet 2022-03-15 18:44:49 +00:00
mochi
325d44d478 Add missing method comments 2022-03-15 18:44:21 +00:00
Joshua MacDonald
0a5f6d3a9d Add an OnError handler; report the reason for disconnects. (#38) 2022-03-15 17:59:52 +00:00
Joshua MacDonald
17253ad8bd Wrap packet errors with cause information (#39) 2022-03-15 17:34:49 +00:00
Joshua MacDonald
9f1c387091 Move two WaitGroup.Add calls (#36) 2022-03-15 17:33:31 +00:00
JB
9c6f602630 Merge pull request #29 from jmacd/jmacd/payload_not_utf8
Support non-UTF8 payloads (per MQTT specification)
2022-02-27 08:36:40 +00:00
Joshua MacDonald
b0dcaabdde Support non-UTF8 payloads per MQTT specification 2022-02-26 22:53:51 -08:00
JB
460f0ef681 revert redis update 2022-02-24 21:19:46 +00:00
JB
6e16765f60 revert server version 2022-02-24 21:19:22 +00:00
JB
2b361df19e Merge pull request #27 from mochi-co/revert-26-master
Revert "added redis persistence mode"
2022-02-24 21:15:08 +00:00
JB
c8c0a5a094 Revert "added redis persistence mode" 2022-02-24 21:10:39 +00:00
JB
4a833dd081 Update server version 2022-02-24 21:07:54 +00:00
JB
81198d9845 Update README.md 2022-02-24 21:07:24 +00:00
JB
6c12d8a71a Merge pull request #26 from wind-c/master
added redis persistence mode
2022-02-24 21:05:35 +00:00
narwal
19b598b672 redis and trie 2022-02-23 16:00:32 +08:00
narwal
b6529f05d3 add redis persistence mode and example 2022-02-22 18:57:31 +08:00
mochi
7f76445cc8 update server version 2022-01-30 10:39:49 +00:00
JB
b1c01792cd Merge pull request #24 from mochi-co/feature/optimise-struct-fields
Optimise Struct Fields + Fixes
2022-01-30 10:38:28 +00:00
mochi
eda03d4338 optimise Server struct 2022-01-30 10:30:34 +00:00
mochi
18070f1f57 pass byte pool by address 2022-01-30 10:30:19 +00:00
mochi
7f10c28a37 remove println 2022-01-30 10:30:01 +00:00
mochi
122531bb27 Pass inflight by address to avoid lock copying 2022-01-28 21:07:10 +00:00
mochi
e6dbcae428 Correct function signature 2022-01-28 21:06:57 +00:00
mochi
98875de568 Update test to match new FixedHeader struct 2022-01-28 21:06:43 +00:00
mochi
c9fd9451af Prevent locks from being copied 2022-01-28 21:06:24 +00:00
mochi
6550b8d680 8bit align struct fields 2022-01-28 21:05:50 +00:00
mochi
a60c96c889 Update comment for clarity 2022-01-28 21:04:15 +00:00
mochi
86e0a5827e Update version to 1.1.0 2022-01-26 20:49:53 +00:00
mochi
06c399b606 indicate ARM32 compatibility 2022-01-26 20:49:42 +00:00
JB
ed117f67a1 Merge pull request #22 from mochi-co/feature/32bit-compatibility
ARM32 Compatibility
2022-01-26 20:36:13 +00:00
JB
880a3299e1 Merge pull request #19 from rkennedy/bugfix/32-bit-atomic-alignment
Improve 32-bit compatibility
2022-01-26 08:02:57 +00:00
Rob Kennedy
1c408d05be Fix encodeLength for 32-bit platforms
When `int` is 32 bits, `MaxInt64` doesn't fit. It's apparent that
`encodeLength` expects to handle 64-bit inputs, so let's make that
explicit, which allows the test to run on all platforms.
2022-01-25 00:22:26 -06:00
Rob Kennedy
fce495f83e Avoid race condition when closing listeners
"Atomic load" followed by "atomic store" is not itself an atomic
operation. This commit replaces that sequence with CompareAndSwap
instead.
2022-01-25 00:22:26 -06:00
Rob Kennedy
471ca00a64 Make atomics work on 32-bit systems
On 32-bit systems, `atomic` requires its 64-bit arguments to have 64-bit
alignment, but the compiler doesn't help ensure that's the case. In this
commit, fields that don't need to hold large numbers have been converted
to 32-bit types, which are always aligned correctly on all platforms.
For fields that may hold large numeric values, padding has been added to
get the necessary alignment, and tests have been added to avoid
regressions.
2022-01-25 00:22:26 -06:00
mochi
a2c0749640 Update server version to 1.0.5 2022-01-24 18:46:34 +00:00
JB
37293aeecf Merge pull request #18 from mochi-co/feature/connect-disconnect-hooks
OnConnect and OnDisconnect Event Hooks
2022-01-24 18:44:39 +00:00
mochi
7a2d4db6a4 Update for OnConnect and OnDisconnect hooks 2022-01-24 18:42:09 +00:00
mochi
03d2a8bc82 Add tests for OnConnect, OnDisconnect 2022-01-24 18:29:18 +00:00
mochi
4b51e5c7d1 Add OnConnect and OnDisconnect hooks to example 2022-01-24 17:42:33 +00:00
mochi
d15ad682bf Call OnDisconnect Event if applicable 2022-01-24 17:42:19 +00:00
mochi
130ffcbb53 Add OnDisconnect Event Hook 2022-01-24 17:42:04 +00:00
mochi
33cf2f991b Add testbolt file to ignore list 2022-01-24 17:41:46 +00:00
mochi
a360ea6a6c Call OnConnect Event if applicable 2022-01-24 17:37:11 +00:00
mochi
ae3aa0d3fa Add OnConnect event hook 2022-01-24 17:36:50 +00:00
mochi
811ae0e1be Prevent locks being copied by passing non-pointer to FromClient 2022-01-24 17:36:14 +00:00
JB
51d6825430 Merge pull request #15 from ClarkQAQ/master
Fixed some bugs, wish the project better and better
2022-01-17 10:08:20 +00:00
clark
514288c53e update tcp.go maybe this will be better 2022-01-16 20:06:49 +08:00
clark
957fc0a049 fix local variable black hole 2022-01-16 18:23:45 +08:00
clark
03f94f948a update mock.go plase use range 2022-01-16 18:22:37 +08:00
clark
1bc752a2b8 fix [ST1005] strings should not be capitalized 2022-01-16 18:21:33 +08:00
clark
b9db59ba12 update websocket.go fix check origin 2022-01-16 18:20:06 +08:00
JB
c0ef58c363 Update README.md 2022-01-14 17:48:21 +00:00
JB
994adea3b4 Merge pull request #14 from mochi-co/feature/allow-clients-value
Add AllowClients Field to packets
2022-01-14 17:38:29 +00:00
mochi
fc61cc9be5 Add example for AllowClients field 2022-01-14 17:04:55 +00:00
mochi
22d7338878 Add test for AllowClients field 2022-01-14 17:04:39 +00:00
mochi
3f28515706 Remove unnecessary type declarations 2022-01-14 17:04:21 +00:00
mochi
7d73ce9caf Add setupServerClients to inherit existing server instance
previously new clients generated a new server object, so system stats were not shared. This change ensures all test clients use the same server
2022-01-14 17:04:01 +00:00
mochi
0758bc961c Add AllowClients check in publishToSubscribers
If AllowClients has been set on a packet, ensure only clients in the slice are sent the message
2022-01-14 17:02:31 +00:00
mochi
8472b9ae8a use .systemInfo instead of .system for clarity 2022-01-14 17:01:42 +00:00
mochi
530a018e80 use .systemInfo instead of .system for clarity 2022-01-14 17:01:31 +00:00
mochi
0b594afb4e Add AllowClients field to packets
AllowClients field can be specified during onMessage event to selectively deliver messages
2022-01-14 16:59:17 +00:00
mochi
9d0ea957bb Increment server version 2022-01-14 16:58:48 +00:00
mochi
8067785ac4 Add tests for InSliceString 2022-01-14 16:58:33 +00:00
mochi
6ffc8a8388 Add InSliceString function
Check if a slice of strings contains a string (until slices package available)
2022-01-14 16:58:21 +00:00
mochi
fb136483d0 Revert server version 2022-01-10 23:50:40 +00:00
mochi
b209cd95f1 increment server version 2022-01-10 23:48:33 +00:00
mochi
3a7e58ec01 Remove unnecessary fmt import 2022-01-10 23:47:33 +00:00
mochi
a674632cce Increment server version 2022-01-10 23:41:46 +00:00
JB
09ddc412c7 Merge pull request #12 from jphastings/remove-erroneous-print 2022-01-10 23:38:33 +00:00
JP Hastings-Spital
6fbd8a5eb2 Remove unnecessary println 2022-01-10 23:36:33 +00:00
JB
d4ae73a97c fix indentation in code blocks
convert tabs to spaces
2022-01-05 21:43:47 +00:00
JB
3ff853a990 Update README.md 2022-01-05 21:41:45 +00:00
mochi
4302eed84f Update vendor 2022-01-05 21:28:00 +00:00
mochi
a1fee6ff68 Update go mod to 1.17 2022-01-05 21:27:52 +00:00
mochi
7fbc0b0187 fix code indents 2022-01-05 21:26:11 +00:00
mochi
8bbca347c4 Update go to 1.17 2022-01-05 21:21:35 +00:00
mochi
b277600823 Increment server version to 1.0.1 2022-01-05 21:14:11 +00:00
JB
685c050fdd Merge pull request #11 from mochi-co/feature/event-hooks-publish
Feature/event hooks publish
2022-01-05 21:13:01 +00:00
mochi
0abbaf5070 fix onmessage test 2022-01-05 21:09:12 +00:00
mochi
1ab1928cff change scheduled message for clarity 2022-01-05 21:05:31 +00:00
mochi
8890bb9dd4 remove redundant code 2022-01-05 21:05:20 +00:00
mochi
f9348aaf93 Update Readme to add Event Hooks section 2022-01-05 20:59:25 +00:00
mochi
c2a42a16ca Merge OnMessage and OnMessageModify 2022-01-05 20:59:14 +00:00
mochi
d14d944de9 Update events example with publish hooks 2022-01-05 20:38:32 +00:00
mochi
480e60b3f0 Adds tests for publishing event hooks 2022-01-05 20:38:10 +00:00
mochi
d4cbf1abdc Add Event Hooks
Adds basic event hooks (OnMessage, OnMessageModify) to the server using the new events library.
2022-01-05 20:38:00 +00:00
mochi
8a1c53432e Add Events
Events library contains event hook types and related utility functions
2022-01-05 20:37:15 +00:00
mochi
7c7b8d58fe Return packets to internal
Now that we can alias types, there's no compelling reason to expose the packets library
2022-01-05 18:10:24 +00:00
JB
ce773b3978 Merge pull request #10 from mochi-co/expose-packets
Expose packets library
2022-01-05 17:06:36 +00:00
JB
f3e7469478 Merge pull request #8 from mochi-co/feature/inline-publish
Inline Publishing
2022-01-05 17:02:14 +00:00
mochi
b5685ca0ee update packets library import reference 2022-01-05 17:01:15 +00:00
mochi
66edb0564c expose packets library 2022-01-05 17:00:51 +00:00
mochi
1d9fa4199c Add .DS_Store to ignore list 2022-01-05 17:00:17 +00:00
mochi
dec880231d Update with direct publishing
Adds information about direct publishing and moves performance section
2022-01-05 13:49:41 +00:00
mochi
21d4e54e74 Add inline publishing example
Adds an example file which demonstrates the usage of the `Publish` method. This file will also be used to demonstrate event hooks.
2022-01-05 13:48:17 +00:00
mochi
aeb4190733 Add tests for new inline publishing method 2022-01-05 13:32:28 +00:00
mochi
484e4abd56 Directly publish messages from embedding system
When the broker is embedded in a larger Go codebase, it is beneficial to be able to publish messages directly from the system to topics. This change provides a Publish method which adds messages to an inline publishing queue in a separate goroutine, which are then processed in the standard way and issued to all clients with matching topic filters.
2022-01-05 13:32:12 +00:00
mochi
d51bad30fc Update comments and rename input parameter for clarity 2022-01-05 13:14:50 +00:00
mochi
060fbffa79 Update comments for clarity 2022-01-05 13:14:15 +00:00
mochi
7c68614912 Add .gitignore
Ensure we're not committing any binaries
2022-01-05 13:13:54 +00:00
JB
124be96c0e Remove Codacy badge 2021-11-01 21:54:40 +00:00
Mochi
b08a57eb89 Update Readme 2020-02-12 22:56:49 +00:00
Mochi
e8e29e95cf Update Readme 2020-02-12 22:49:11 +00:00
Mochi
8e468852b2 Update Chart Labels 2020-02-12 22:48:50 +00:00
Mochi
7c23925ec6 Update Readme 2020-02-12 21:59:56 +00:00
Mochi
aa90dd80b3 Update Badges 2020-02-12 21:47:49 +00:00
Mochi
a98e16790b Update Badges 2020-02-12 21:47:21 +00:00
Mochi
bec9401213 Fix test races 2020-02-12 21:15:15 +00:00
Mochi
7103f0439a Update travis 2020-02-12 20:43:11 +00:00
Mochi
b605c94e5b Fix examples 2020-02-12 20:39:30 +00:00
Mochi
31a026b14a Add TravisCI 2020-02-12 20:35:53 +00:00
Mochi
4fdf2ae2fe Add badges 2020-02-12 20:30:17 +00:00
1013 changed files with 210934 additions and 132853 deletions

43
.github/workflows/build.yml vendored Normal file
View File

@@ -0,0 +1,43 @@
name: build
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.19
- name: Vet
run: go vet ./...
- name: Test
run: go test -race ./... && echo true
coverage:
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.19.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
cmd/mqtt
.DS_Store
*.db

103
.golangci.yml Normal file
View File

@@ -0,0 +1,103 @@
linters:
disable-all: false
fix: false # Fix found issues (if it's supported by the linter).
enable:
# - asasalint
# - asciicheck
# - bidichk
# - bodyclose
# - containedctx
# - contextcheck
#- cyclop
# - deadcode
- decorder
# - depguard
# - dogsled
# - dupl
- durationcheck
# - errchkjson
# - errname
- errorlint
# - execinquery
# - exhaustive
# - exhaustruct
# - exportloopref
#- forcetypeassert
#- forbidigo
#- funlen
#- gci
# - gochecknoglobals
# - gochecknoinits
# - gocognit
# - goconst
# - gocritic
- gocyclo
- godot
# - godox
# - goerr113
# - gofmt
# - gofumpt
# - goheader
- goimports
# - golint
# - gomnd
# - gomoddirectives
# - gomodguard
# - goprintffuncname
- gosec
- gosimple
- govet
# - grouper
# - ifshort
- importas
- ineffassign
# - interfacebloat
# - interfacer
# - ireturn
# - lll
# - maintidx
# - makezero
- maligned
- misspell
# - nakedret
# - nestif
# - nilerr
# - nilnil
# - nlreturn
# - noctx
# - nolintlint
# - nonamedreturns
# - nosnakecase
# - nosprintfhostport
# - paralleltest
# - prealloc
# - predeclared
# - promlinter
- reassign
# - revive
# - rowserrcheck
# - scopelint
# - sqlclosecheck
# - staticcheck
# - structcheck
# - stylecheck
# - tagliatelle
# - tenv
# - testpackage
# - thelper
- tparallel
# - typecheck
- unconvert
- unparam
- unused
- usestdlibvars
# - varcheck
# - varnamelen
- wastedassign
- whitespace
# - wrapcheck
# - wsl
disable:
- errcheck

31
Dockerfile Normal file
View File

@@ -0,0 +1,31 @@
FROM golang:1.19.0-alpine3.15 AS builder
RUN apk update
RUN apk add git
WORKDIR /app
COPY go.mod ./
COPY go.sum ./
RUN go mod download
COPY . ./
RUN go build -o /app/mochi ./cmd
FROM alpine
WORKDIR /
COPY --from=builder /app/mochi .
# tcp
EXPOSE 1883
# websockets
EXPOSE 1882
# dashboard
EXPOSE 8080
ENTRYPOINT [ "/mochi" ]

View File

@@ -1,7 +1,7 @@
The MIT License (MIT)
Copyright (c) 2019 Jonathan Blake (mochi)
Copyright (c) 2019, 2022 Jonathan Blake (mochi-co)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

454
README.md
View File

@@ -1,178 +1,382 @@
# Mochi MQTT
### A High-performance MQTT server in Go (v3.0 | v3.1.1)
Mochi MQTT is an embeddable high-performance MQTT broker server written in Go, and compliant with the MQTT v3.0 and v3.1.1 specification for the development of IoT and smarthome projects. The server can be used either as a standalone binary or embedded as a library in your own projects. Mochi MQTT message throughput is comparable with world favourites such as Mosquitto, Mosca, and VerneMQ.
<p align="center">
#### What is MQTT?
MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks. [Learn more](https://mqtt.org/faq)
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
#### Mochi MQTT Features
- Paho MQTT 3.0 / 3.1.1 compatible.
- Full MQTT Feature-set (QoS, Retained, $SYS)
- Trie-based Subscription model.
- Ring Buffer packet codec.
- TCP, Websocket, (including SSL/TLS) and Dashboard listeners.
- Interfaces for Client Authentication and Topic access control.
- Bolt-backed persistence and storage interfaces.
</p>
#### Roadmap
- Inline Pub-sub (without client) and event hooks
- Docker Image
- MQTT v5 compatibility
# Mochi MQTT Broker
## The fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker server
Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) MQTT v5 broker server written in Go, designed for the development of telemetry and internet-of-things projects. The server can be used either as a standalone binary or embedded as a library in your own applications, and has been designed to be as lightweight and fast as possible, with great care taken to ensure the quality and maintainability of the project.
#### Performance (messages/second)
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a 13-inch, Early 2015 Macbook Pro (2.7 GHz Intel Core i5). Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better.
### What is MQTT?
MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol.
> As usual, any performance benchmarks should be taken with a pinch of salt, but are shown to demonstrate typical throughput compared to the other leading MQTT brokers.
## What's new in Version 2.0.0?
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
**Single Client, 10,000 messages**
Don't forget to use the new v2 import paths:
```go
import "github.com/mochi-co/mqtt/v2"
```
![1 Client, 10,000 Messages](assets/benchmarkchart_1_10000.png "1 Client, 10,000 Messages")
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
- User and MQTTv5 Packet Properties
- Topic Aliases
- Shared Subscriptions
- Subscription Options and Subscription Identifiers
- Message Expiry
- Client Session Expiry
- Send and Receive QoS Flow Control Quotas
- Server-side Disconnect and Auth Packets
- Will Delay Intervals
- Plus all the original MQTT features of Mochi MQTT v1, such as Full QoS(0,1,2), $SYS topics, retained messages, etc.
- Developer-centric:
- Most core broker code is now exported and accessible, for total developer control.
- Full featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
- Direct Packet Injection using special inline client, or masquerade as existing clients.
- Performant and Stable:
- Our classic trie-based Topic-Subscription model.
- A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
- Over a thousand carefully considered unit test scenarios.
- TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners.
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=1 -num-messages=10000`
> There is no upgrade path from v1.0.0. Please review the documentation and this readme to get a sense of the changes required (e.g. the v1 events system, auth, and persistence have all been replaced with the new hooks system).
| | Mochi | Mosquitto | EMQX | VerneMQ | Mosca |
| :----------- | --------: | ----------: | -------: | --------: | --------:
| SEND High | 36505 | 30597 | 27202 | 32782 | 30125 |
| SEND Low | 36505 | 30597 | 27202 | 32782 | 30125 |
| SEND Median | 36505 | 30597 | 27202 |32782 | 30125 |
| RECV High | 152221 | 59130 | 7879 | 17551 | 9145 |
| RECV Low | 152221 | 59130 | 7879 | 17551 | 9145 |
| RECV Median | 152221 | 59130 | 7879 | 17551 | 9145 |
### Compatibility Notes
Because of the overlap between the v5 specification and previous versions of mqtt, the server can accept both v5 and v3 clients, but note that in cases where both v5 an v3 clients are connected, properties and features provided for v5 clients will be downgraded for v3 clients (such as user properties).
**10 Clients, 1,000 Messages**
Support for MQTT v3.0.0 and v3.1.1 is considered hybrid-compatibility. Where not specifically restricted in the v3 specification, more modern and safety-first v5 behaviours are used instead - such as expiry for inflight and retained messages, and clients - and quality-of-service flow control limits.
![10 Clients, 1,000 Messages](assets/benchmarkchart_10_1000.png "10 Clients, 1,000 Messages")
## Roadmap
- Please [open an issue](https://github.com/mochi-co/mqtt/issues) to request new features or event hooks!
- Cluster support.
- Enhanced Metrics support.
- File-based server configuration (supporting docker).
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=1000`
| | Mochi | Mosquitto | EMQX | VerneMQ | Mosca |
| :----------- | --------: | ----------: | -------: | --------: | --------:
| SEND High | 37193 | 15775 | 17455 | 34138 | 36575 |
| SEND Low | 6529 | 6446 | 7714 | 8583 | 7383 |
| SEND Median | 15127 | 7813 | 10305 | 9887 | 8169 |
| RECV High | 33535 | 3710 | 3022 | 4534 | 9411 |
| RECV Low | 7484 | 2661 | 1689 | 2021 | 2275 |
| RECV Median | 11427 | 3142 | 1831 | 2468 | 4692 |
**10 Clients, 10,000 Messages**
![10 Clients, 10000 Messages](assets/benchmarkchart_10_10000.png "10 Clients, 10000 Messages")
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| | Mochi | Mosquitto | EMQX | VerneMQ | Mosca |
| :----------- | --------: | ----------: | -------: | --------: | --------:
| SEND High | 13153 | 13270 | 12229 | 13025 | 38446 |
| SEND Low | 8728 | 8513 | 8193 | 6483 | 3889 |
| SEND Median | 9045 | 9532 | 9252 | 8031 | 9210 |
| RECV High | 20774 | 5052 | 2093 | 2071 | 43008 |
| RECV Low | 10718 |3995 | 1531 | 1673 | 18764 |
| RECV Median | 16339 | 4607 | 1620 | 1907 | 33524 |
**500 Clients, 100 Messages**
![500 Clients, 100 Messages](assets/benchmarkchart_500_100.png "500 Clients, 100 Messages")
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=500 -num-messages=100`
| | Mochi | Mosquitto | EMQX | VerneMQ | Mosca |
| :----------- | --------: | ----------: | -------: | --------: | --------:
| SEND High | 70688 | 72686 | 71392 | 75336 | 73192 |
| SEND Low | 1021 | 2577 | 1603 | 8417 | 2344 |
| SEND Median | 49871 | 33076 | 33637 | 35200 | 31312 |
| RECV High | 116163 | 4215 | 3427 | 5484 | 10100 |
| RECV Low | 1044 | 156 | 56 | 83 | 169 |
| RECV Median | 24398 | 208 | 94 | 413 | 474 |
#### Using the Broker
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the `main.go` entrypoint in the `cmd` folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners. A docker image is coming soon.
## Quick Start
### Running the Broker with Go
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the [cmd/main.go](cmd/main.go) entrypoint in the [cmd](cmd) folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners.
```
cd cmd
go build -o mqtt && ./mqtt
```
#### Quick Start
### Using Docker
A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Websocket, TCP, and Stats server:
```sh
docker build -t mochi:latest .
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
```
## Developing with Mochi MQTT
### Importing as a package
Importing Mochi MQTT as a package requires just a few lines of code to get started.
``` go
import (
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/v2"
)
func main() {
// Create the new MQTT Server.
server := mqtt.New()
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", ":1883")
// Add the listener to the server with default options (nil).
err := server.AddListener(tcp, nil)
if err != nil {
log.Fatal(err)
}
// Start the broker. Serve() is blocking - see examples folder
// for usage ideas.
err = server.Serve()
// Create the new MQTT Server.
server := mqtt.New(nil)
// Allow all connections.
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.Serve()
if err != nil {
log.Fatal(err)
}
}
```
Examples of running the broker with various configurations can be found in the `examples` folder.
Examples of running the broker with various configurations can be found in the [examples](examples) folder.
#### Network Listeners
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
- `listeners.NewTCP(id, address string)` - A TCP Listener, taking a unique ID and a network address to bind.
- `listeners.NewWebsocket(id, address string)` A Websocket Listener
- `listeners.NewHTTPStats()` An HTTP $SYS info dashboard
##### Configuring Network Listeners
When a listener is added to the server using `server.AddListener`, a `*listeners.Config` may be passed as the second argument.
- `listeners.NewTCP(...)` - A TCP listener.
- `listeners.NewWebsocket(...)` A Websocket listener.
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
##### Authentication and ACL
Authentication and ACL may be configured on a per-listener basis by providing an Auth Controller to the listener configuration. Custom Auth Controllers should satisfy the `auth.Controller` interface found in `listeners/auth`. Two default controllers are provided, `auth.Allow`, which allows all traffic, and `auth.Disallow`, which denies all traffic.
A `*listeners.Config` may be passed to configure TLS.
Examples of usage can be found in the [examples](examples) folder or [cmd/main.go](cmd/main.go).
### Server Options and Capabilities
A number of configurable options are available which can be used to alter the behaviour or restrict access to certain features in the server.
```go
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
})
server := mqtt.New(&mqtt.Options{
Capabilities: mqtt.Capabilities{
MaximumSessionExpiryInterval: 3600,
Compatibilities: mqtt.Compatibilities{
ObscureNotAuthorized: true,
},
},
SysTopicResendInterval: 10,
})
```
> If no auth controller is provided in the listener configuration, the server will default to _Disallowing_ all traffic to prevent unintentional security issues.
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options.
## Event Hooks
A universal event hooks system allows developers to hook into various parts of the server and client life cycle to add and modify functionality of the broker. These universal hooks are used to provide everything from authentication, persistent storage, to debugging tools.
Hooks are stackable - you can add multiple hooks to a server, and they will be run in the order they were added. Some hooks modify values, and these modified values will be passed to the subsequent hooks before being returned to the runtime code.
| Type | Import | Info |
| -- | -- | -- |
| Access Control | [mochi-co/mqtt/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to all topics. |
| Access Control | [mochi-co/mqtt/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
| Persistence | [mochi-co/mqtt/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
| Persistence | [mochi-co/mqtt/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
| Persistence | [mochi-co/mqtt/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
| Debugging | [mochi-co/mqtt/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |
Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-co/mqtt/issues) and let everyone know!
### Access Control
#### Allow Hook
By default, Mochi MQTT uses a DENY-ALL access control rule. To allow connections, this must overwritten using an Access Control hook. The simplest of these hooks is the `auth.AllowAll` hook, which provides ALLOW-ALL rules to all connections, subscriptions, and publishing. It's also the simplest hook to use:
##### SSL
SSL may be configured on both the TCP and Websocket listeners by providing a public-private PEM key pair to the listener configuration as `[]byte` slices.
```go
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: publicCertificate,
PrivateKey: privateKey,
},
})
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
```
> Note the mandatory inclusion of the Auth Controller!
#### Data Persistence
Mochi MQTT provides a `persistence.Store` interface for developing and attaching persistent stores to the broker. The default persistence mechanism packaged with the broker is backed by [Bolt](https://github.com/etcd-io/bbolt) and can be enabled by assigning a `*bolt.Store` to the server.
> Don't do this if you are exposing your server to the internet or untrusted networks - it should really be used for development, testing, and debugging only.
#### Auth Ledger
The Auth Ledger hook provides a sophisticated mechanism for defining access rules in a struct format. Auth ledger rules come in two forms: Auth rules (connection), and ACL rules (publish subscribe).
Auth rules have 4 optional criteria and an assertion flag:
| Criteria | Usage |
| -- | -- |
| Client | client id of the connecting client |
| Username | username of the connecting client |
| Password | password of the connecting client |
| Remote | the remote address or ip of the client |
| Allow | true (allow this user) or false (deny this user) |
ACL rules have 3 optional criteria and an filter match:
| Criteria | Usage |
| -- | -- |
| Client | client id of the connecting client |
| Username | username of the connecting client |
| Remote | the remote address or ip of the client |
| Filters | an array of filters to match |
Rules are processed in index order (0,1,2,3), returning on the first matching rule. See [hooks/auth/ledger.go](hooks/auth/ledger.go) to review the structs.
```go
// import "github.com/mochi-co/mqtt/server/persistence/bolt"
err = server.AddStore(bolt.New("mochi.db", nil))
if err != nil {
log.Fatal(err)
}
server := mqtt.New(nil)
err := server.AddHook(new(auth.Hook), &auth.Options{
Ledger: &auth.Ledger{
Auth: auth.AuthRules{ // Auth disallows all by default
{Username: "peach", Password: "password1", Allow: true},
{Username: "melon", Password: "password2", Allow: true},
{Remote: "127.0.0.1:*", Allow: true},
{Remote: "localhost:*", Allow: true},
},
ACL: auth.ACLRules{ // ACL allows all by default
{Remote: "127.0.0.1:*"}, // local superuser allow all
{
// user melon can read and write to their own topic
Username: "melon", Filters: auth.Filters{
"melon/#": auth.ReadWrite,
"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
},
},
{
// Otherwise, no clients have publishing permissions
Filters: auth.Filters{
"#": auth.ReadOnly,
"updates/#": auth.Deny,
},
},
},
}
})
```
The ledger can also be stored as JSON or YAML and loaded using the Data field:
```go
err = server.AddHook(new(auth.Hook), &auth.Options{
Data: data, // build ledger from byte slice: yaml or json
})
```
See [examples/auth/encoded/main.go](examples/auth/encoded/main.go) for more information.
### Persistent Storage
#### Redis
A basic Redis storage hook is available which provides persistence for the broker. It can be added to the server in the same fashion as any other hook, with several options. It uses github.com/go-redis/redis/v8 under the hook, and is completely configurable through the Options value.
```go
err := server.AddHook(new(redis.Hook), &redis.Options{
Options: &rv8.Options{
Addr: "localhost:6379", // default redis address
Password: "", // your password
DB: 0, // your redis db
},
})
if err != nil {
log.Fatal(err)
}
```
For more information on how the redis hook works, or how to use it, see the [examples/persistence/redis/main.go](examples/persistence/redis/main.go) or [hooks/storage/redis](hooks/storage/redis) code.
#### Badger DB
There's also a BadgerDB storage hook if you prefer file based storage. It can be added and configured in much the same way as the other hooks (with somewhat less options).
```go
err := server.AddHook(new(badger.Hook), &badger.Options{
Path: badgerPath,
})
if err != nil {
log.Fatal(err)
}
```
For more information on how the badger hook works, or how to use it, see the [examples/persistence/badger/main.go](examples/persistence/badger/main.go) or [hooks/storage/badger](hooks/storage/badger) code.
There is also a BoltDB hook which has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go).
## Developing with Event Hooks
Many hooks are available for interacting with the broker and client lifecycle.
The function signatures for all the hooks and `mqtt.Hook` interface can be found in [hooks.go](hooks.go).
> The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks be used to control and modify all incoming and outgoing packets.
| Function | Usage |
| -------------------------- | -- |
| OnStarted | Called when the server has successfully started.|
| OnStopped | Called when the server has successfully stopped. |
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
| OnSysInfoTick | Called when the $SYS topic values are published out. |
| OnConnect | Called when a new client connects |
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
| OnPacketSent | Called when a packet has been sent to a client. |
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.|
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnRetainMessage | Called then a published message is retained. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
### Packet Injection
It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement).
```go
cl := mqtt.NewInlineClient("inline", "local")
server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "direct/publish",
Payload: []byte("scheduled message"),
})
```
> MQTT packets still need to be correctly formed, so refer our [the test packets catalogue](packets/tpackets.go) and [MQTTv5 Specification](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) for inspiration.
See the [hooks example](examples/hooks/main.go) to see this feature in action.
### Testing
#### Unit Tests
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
```
go run --cover ./...
```
> Persistence is on-demand (not flushed) and will potentially reduce throughput when compared to the standard in-memory store. Only use it if you need to maintain state through restarts.
#### Paho Interoperability Test
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the test with `python3 client_test.py` from the _interoperability_ folder.
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the mqtt v5 and v3 tests with `python3 client_test5.py` from the _interoperability_ folder.
> Note that there are currently a number of outstanding issues regarding false negatives in the paho suite, and as such, certain compatibility modes are enabled in the `paho/main.go` example.
## Performance Benchmarks
Mochi MQTT performance is comparable with popular brokers such as Mosquitto, EMQX, and others.
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a Apple Macbook Air M2, using `cmd/main.go` default settings. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better.
> The values presented in the benchmark are not representative of true messages per second throughput. They rely on an unusual calculation by mqtt-stresser, but are usable as they are consistent across all brokers.
> Benchmarks are provided as a general performance expectation guideline only.
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 |
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 |
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 |
Million Message Challenge (hit the server with 1 million messages immediately):
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
> Not sure what's going on with EMQX here, perhaps the docker out-of-the-box settings are not optimal, so take it with a pinch of salt as we know for a fact it's a solid piece of software.
## Stargazers over time 🥰
[![Stargazers over time](https://starchart.cc/mochi-co/mqtt.svg)](https://starchart.cc/mochi-co/mqtt)
Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-co/mqtt/issues)
## Contributions
Contributions and feedback are both welcomed and encouraged! Open an [issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request.
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 37 KiB

571
clients.go Normal file
View File

@@ -0,0 +1,571 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
"github.com/mochi-co/mqtt/v2/packets"
)
const (
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
)
// ReadFn is the function signature for the function used for reading and processing new packets.
type ReadFn func(*Client, packets.Packet) error
// Clients contains a map of the clients known by the broker.
type Clients struct {
internal map[string]*Client // clients known by the broker, keyed on client id.
sync.RWMutex
}
// NewClients returns an instance of Clients.
func NewClients() *Clients {
return &Clients{
internal: make(map[string]*Client),
}
}
// Add adds a new client to the clients map, keyed on client id.
func (cl *Clients) Add(val *Client) {
cl.Lock()
defer cl.Unlock()
cl.internal[val.ID] = val
}
// GetAll returns all the clients.
func (cl *Clients) GetAll() map[string]*Client {
cl.RLock()
defer cl.RUnlock()
m := map[string]*Client{}
for k, v := range cl.internal {
m[k] = v
}
return m
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
defer cl.RUnlock()
val, ok := cl.internal[id]
return val, ok
}
// Len returns the length of the clients map.
func (cl *Clients) Len() int {
cl.RLock()
defer cl.RUnlock()
val := len(cl.internal)
return val
}
// Delete removes a client from the internal map.
func (cl *Clients) Delete(id string) {
cl.Lock()
defer cl.Unlock()
delete(cl.internal, id)
}
// GetByListener returns clients matching a listener id.
func (cl *Clients) GetByListener(id string) []*Client {
cl.RLock()
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
clients = append(clients, client)
}
}
return clients
}
// Client contains information about a client known by the broker.
type Client struct {
Properties ClientProperties // client properties
State ClientState // the operational state of the client.
Net ClientConnection // network connection state of the clinet
ID string // the client id.
ops *ops // ops provides a reference to server ops.
sync.RWMutex // mutex
}
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
Inline bool // client is an inline programmetic client
}
// ClientProperties contains the properties which define the client behaviour.
type ClientProperties struct {
Props packets.Properties
Will Will
Username []byte
ProtocolVersion byte
Clean bool
}
// Will contains the last will and testament details for a client connection.
type Will struct {
Payload []byte // -
User []packets.UserProperty // -
TopicName string // -
Flag uint32 // 0,1
WillDelayInterval uint32 // -
Qos byte // -
Retain bool // -
}
// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
endOnce sync.Once // only end once
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
keepalive uint16 // the number of seconds the connection can wait
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, o *ops) *Client {
cl := &Client{
Net: ClientConnection{
conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Remote: c.RemoteAddr().String(),
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
ops: o,
}
cl.refreshDeadline(cl.State.keepalive)
return cl
}
// NewInlineClient returns a client used when publishing from the embedding system.
func NewInlineClient(id, remote string) *Client {
return &Client{
ID: id,
Net: ClientConnection{
Remote: remote,
Inline: true,
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
}
}
// newClientStub returns an instance of Client with minimal initializations, such as
// restoring client data from a db. In particular, the client is marked as offline (done).
func newClientStub() *Client {
return &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
done: 1,
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
}
}
// ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid
cl.Properties.ProtocolVersion = pk.ProtocolVersion
cl.Properties.Username = pk.Connect.Username
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
cl.ID = pk.Connect.ClientIdentifier
if cl.ID == "" {
cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
cl.Properties.Props.AssignedClientID = cl.ID
}
cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}
if pk.Connect.WillFlag {
cl.Properties.Will = Will{
Qos: pk.Connect.WillQos,
Retain: pk.Connect.WillRetain,
Payload: pk.Connect.WillPayload,
TopicName: pk.Connect.WillTopic,
WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
User: pk.Connect.WillProperties.User,
}
if pk.Properties.SessionExpiryIntervalFlag &&
pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
}
if pk.Connect.WillFlag {
cl.Properties.Will.Flag = 1 // atomic for checking
}
}
cl.refreshDeadline(cl.State.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.Net.conn != nil {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
// NextPacketID returns the next available (unused) packet id for the client.
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
i = atomic.LoadUint32(&cl.State.packetID)
started := i + 1
overflowed := false
for {
if i >= 65535 {
overflowed = true
i = 1
} else {
i++
}
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
break
}
}
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
func (cl *Client) ResendInflightMessages(force bool) error {
if cl.State.Inflight.Len() == 0 {
return nil
}
for _, tk := range cl.State.Inflight.GetAll(false) {
if tk.FixedHeader.Type == packets.Publish {
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
}
// cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends)
err := cl.WritePacket(tk)
if err != nil {
return err
}
if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosComplete(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
}
}
}
return nil
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
var deleted int64
for _, tk := range cl.State.Inflight.GetAll(false) {
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted++
}
}
}
return deleted
}
// Read reads incoming packets from the connected client and transforms them into
// packets to be handled by the packetHandler.
func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
return nil
}
cl.refreshDeadline(cl.State.keepalive)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
return err
}
pk, err := cl.ReadPacket(fh)
if err != nil {
return err
}
err = packetHandler(cl, pk) // Process inbound packet.
if err != nil {
return err
}
}
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.done) == 1 {
return
}
cl.State.endOnce.Do(func() {
if cl.Net.conn != nil {
_ = cl.Net.conn.Close() // omit close error
}
if err != nil {
cl.State.stopCause.Store(err)
}
atomic.StoreUint32(&cl.State.done, 1)
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
// StopCause returns the reason the client connection was stopped, if any.
func (cl *Client) StopCause() error {
if cl.State.stopCause.Load() == nil {
return nil
}
return cl.State.stopCause.Load().(error)
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
return ErrConnectionClosed
}
b, err := cl.Net.bconn.ReadByte()
if err != nil {
return err
}
err = fh.Decode(b)
if err != nil {
return err
}
var bu int
fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn)
if err != nil {
return err
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil
}
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
pk.FixedHeader = *fh
p := make([]byte, pk.FixedHeader.Remaining)
n, err := io.ReadFull(cl.Net.bconn, p)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
px := append([]byte{}, p[:]...)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectDecode(px)
case packets.Disconnect:
err = pk.DisconnectDecode(px)
case packets.Connack:
err = pk.ConnackDecode(px)
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
case packets.Pubrec:
err = pk.PubrecDecode(px)
case packets.Pubrel:
err = pk.PubrelDecode(px)
case packets.Pubcomp:
err = pk.PubcompDecode(px)
case packets.Subscribe:
err = pk.SubscribeDecode(px)
case packets.Suback:
err = pk.SubackDecode(px)
case packets.Unsubscribe:
err = pk.UnsubscribeDecode(px)
case packets.Unsuback:
err = pk.UnsubackDecode(px)
case packets.Pingreq:
case packets.Pingresp:
case packets.Auth:
err = pk.AuthDecode(px)
default:
err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
}
if err != nil {
return pk, err
}
pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) error {
if atomic.LoadUint32(&cl.State.done) == 1 {
return ErrConnectionClosed
}
if cl.Net.conn == nil {
return nil
}
defer cl.refreshDeadline(cl.State.keepalive)
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
pk.ProtocolVersion = cl.Properties.ProtocolVersion
if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
}
if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
var err error
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectEncode(buf)
case packets.Connack:
err = pk.ConnackEncode(buf)
case packets.Publish:
err = pk.PublishEncode(buf)
case packets.Puback:
err = pk.PubackEncode(buf)
case packets.Pubrec:
err = pk.PubrecEncode(buf)
case packets.Pubrel:
err = pk.PubrelEncode(buf)
case packets.Pubcomp:
err = pk.PubcompEncode(buf)
case packets.Subscribe:
err = pk.SubscribeEncode(buf)
case packets.Suback:
err = pk.SubackEncode(buf)
case packets.Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case packets.Unsuback:
err = pk.UnsubackEncode(buf)
case packets.Pingreq:
err = pk.PingreqEncode(buf)
case packets.Pingresp:
err = pk.PingrespEncode(buf)
case packets.Disconnect:
err = pk.DisconnectEncode(buf)
case packets.Auth:
err = pk.AuthEncode(buf)
default:
err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
}
if err != nil {
return err
}
if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.conn)
if err != nil {
return err
}
atomic.AddInt64(&cl.ops.info.BytesSent, n)
atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
if pk.FixedHeader.Type == packets.Publish {
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
}
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
return err
}

711
clients_test.go Normal file
View File

@@ -0,0 +1,711 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var errClientStop = errors.New("test stop")
func newClient() (cl *Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = NewClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
},
})
cl.ID = "mochi"
cl.State.Inflight.maximumSendQuota = 5
cl.State.Inflight.sendQuota = 5
cl.State.Inflight.maximumReceiveQuota = 10
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
return
}
func TestNewInflights(t *testing.T) {
require.NotNil(t, NewInflights().internal)
}
func TestNewClients(t *testing.T) {
cl := NewClients()
require.NotNil(t, cl.internal)
}
func TestClientsAdd(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
require.Contains(t, cl.internal, "t1")
}
func TestClientsGet(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
client, ok := cl.Get("t1")
require.Equal(t, true, ok)
require.Equal(t, "t1", client.ID)
}
func TestClientsGetAll(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
cl.Add(&Client{ID: "t3"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Contains(t, cl.internal, "t3")
clients := cl.GetAll()
require.Len(t, clients, 3)
}
func TestClientsLen(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Equal(t, 2, cl.Len())
}
func TestClientsDelete(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
require.Contains(t, cl.internal, "t1")
cl.Delete("t1")
_, ok := cl.Get("t1")
require.Equal(t, false, ok)
require.Nil(t, cl.internal["t1"])
}
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
clients := cl.GetByListener("tcp1")
require.NotEmpty(t, clients)
require.Equal(t, 1, len(clients))
require.Equal(t, "tcp1", clients[0].Net.Listener)
}
func TestNewClient(t *testing.T) {
cl, _, _ := newClient()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Nil(t, cl.StopCause())
}
func TestNewClientStub(t *testing.T) {
cl := newClientStub()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
}
func TestNewInlineClient(t *testing.T) {
cl := NewInlineClient("inline", "local")
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done))
require.Equal(t, "inline", cl.ID)
require.Equal(t, "local", cl.Net.Remote)
}
func TestClientParseConnect(t *testing.T) {
cl, _, _ := newClient()
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillTopic: "lwt",
WillPayload: []byte("lol gg"),
WillQos: 1,
WillRetain: false,
},
Properties: packets.Properties{
ReceiveMaximum: uint16(5),
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
}
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
cl, _, _ := newClient()
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillProperties: packets.Properties{
WillDelayInterval: 200,
},
},
Properties: packets.Properties{
SessionExpiryInterval: 100,
SessionExpiryIntervalFlag: true,
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval)
}
func TestClientParseConnectNoID(t *testing.T) {
cl, _, _ := newClient()
cl.ParseConnect("tcp1", packets.Packet{})
require.NotEmpty(t, cl.ID)
}
func TestClientNextPacketID(t *testing.T) {
cl, _, _ := newClient()
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(2), i)
}
func TestClientNextPacketIDInUse(t *testing.T) {
cl, _, _ := newClient()
// skip over 2
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(3), i)
// Skip over overflow
cl.State.Inflight.Set(packets.Packet{PacketID: 65535})
atomic.StoreUint32(&cl.State.packetID, 65534)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
}
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newClient()
for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
}
i, err := cl.NextPacketID()
require.Equal(t, uint32(0), i)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newClient()
cl.State.packetID = uint32(65534)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(65535), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
}
func TestClientClearInflights(t *testing.T) {
cl, _, _ := newClient()
n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
cl.ClearInflights(n, 4)
require.Equal(t, 2, cl.State.Inflight.Len())
}
func TestClientResendInflightMessages(t *testing.T) {
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
cl, r, w := newClient()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
go func() {
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, 0, cl.State.Inflight.Len())
require.Equal(t, pk1.RawBytes, buf)
}
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newClient()
r.Close()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
err := cl.ResendInflightMessages(true)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrClosedPipe)
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
cl, _, _ := newClient()
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
}
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0x00})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived))
}
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.Equal(t, io.EOF, err)
}
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
}
func TestClientReadOK(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload,
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
}()
var pks []packets.Packet
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
pks = append(pks, pk)
return nil
})
}()
err := <-o
require.Error(t, err)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 2, len(pks))
require.Equal(t, []packets.Packet{
{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 18,
},
TopicName: "a/b/c",
Payload: []byte("hello mochi"),
},
{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 11,
},
TopicName: "d/e/f",
Payload: []byte("yeah"),
},
}, pks)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived))
}
func TestClientReadDone(t *testing.T) {
cl, _, _ := newClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
return nil
})
}()
require.NoError(t, <-o)
}
func TestClientStop(t *testing.T) {
cl, _, _ := newClient()
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
}
func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
})
r.Close()
}()
cl.Net.bconn = nil
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, ErrConnectionClosed, err)
}
func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
}()
err := cl.Read(func(cl *Client, pk packets.Packet) error {
return errors.New("test")
})
require.Error(t, err)
}
func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
pk, err := cl.ReadPacket(fh)
require.NoError(t, err)
require.NotNil(t, pk)
require.Equal(t, packets.Packet{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 11,
},
TopicName: "d/e/f",
Payload: []byte("yeah"),
}, pk)
}
func TestClientReadPacket(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
for _, tx := range pkTable {
tt := tx // avoid data race
t.Run(tt.Desc, func(t *testing.T) {
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
go func() {
r.Write(tt.RawBytes)
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
if tt.Packet.ProtocolVersion == 5 {
cl.Properties.ProtocolVersion = 5
} else {
cl.Properties.ProtocolVersion = 0
}
pk, err := cl.ReadPacket(fh)
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc)
require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc)
if tt.Packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc)
}
})
}
}
func TestClientWritePacket(t *testing.T) {
for _, tt := range pkTable {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
o := make(chan []byte)
go func() {
buf, err := io.ReadAll(r)
require.NoError(t, err)
o <- buf
}()
err := cl.WritePacket(*tt.Packet)
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
cl.Stop(errClientStop)
time.Sleep(time.Millisecond * 1)
// The stop cause is either the test error, EOF, or a
// closed pipe, depending on which goroutine runs first.
err = cl.StopCause()
require.True(t,
errors.Is(err, errClientStop) ||
errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe))
require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent))
if tt.Packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent))
}
}
}
func TestWriteClientOversizePacket(t *testing.T) {
cl, _, _ := newClient()
cl.Properties.Props.MaximumPacketSize = 2
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
err := cl.WritePacket(pk)
require.Error(t, err)
require.ErrorIs(t, packets.ErrPacketTooLarge, err)
}
func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
Type: 0,
Remaining: 11,
})
require.Error(t, err)
}
func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
Remaining: 1,
})
require.Error(t, err)
}
func TestClientWritePacketWriteNoConn(t *testing.T) {
cl, _, _ := newClient()
cl.Stop(errClientStop)
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
require.Equal(t, ErrConnectionClosed, err)
}
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newClient()
cl.Net.conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
}
func TestClientWritePacketInvalidPacket(t *testing.T) {
cl, _, _ := newClient()
err := cl.WritePacket(packets.Packet{})
require.Error(t, err)
}
var (
pkTable = []packets.TPacketCase{
packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311),
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5),
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession),
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic),
packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5),
packets.TPacketData[packets.Puback].Get(packets.TPuback),
packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe),
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5),
packets.TPacketData[packets.Suback].Get(packets.TSuback),
packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5),
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe),
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5),
packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback),
packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5),
packets.TPacketData[packets.Pingreq].Get(packets.TPingreq),
packets.TPacketData[packets.Pingresp].Get(packets.TPingresp),
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect),
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5),
packets.TPacketData[packets.Auth].Get(packets.TAuth),
}
)

View File

@@ -1,17 +1,19 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
@@ -28,37 +30,36 @@ func main() {
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Broker initializing..."))
fmt.Println(aurora.Cyan("TCP"), *tcpAddr)
fmt.Println(aurora.Cyan("Websocket"), *wsAddr)
fmt.Println(aurora.Cyan("$SYS Dashboard"), *infoAddr)
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
server := mqtt.New()
tcp := listeners.NewTCP("t1", *tcpAddr)
err := server.AddListener(tcp, nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
ws := listeners.NewWebsocket("ws1", *wsAddr)
err = server.AddListener(ws, nil)
ws := listeners.NewWebsocket("ws1", *wsAddr, nil)
err = server.AddListener(ws)
if err != nil {
log.Fatal(err)
}
stats := listeners.NewHTTPStats("stats", *infoAddr)
err = server.AddListener(stats, nil)
stats := listeners.NewHTTPStats("stats", *infoAddr, nil, server.Info)
err = server.AddListener(stats)
if err != nil {
log.Fatal(err)
}
go server.Serve()
fmt.Println(aurora.BgMagenta(" Started! "))
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
server.Log.Info().Msg("main.go finished")
}

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
authRules := &auth.Ledger{
Auth: auth.AuthRules{ // Auth disallows all by default
{Username: "peach", Password: "password1", Allow: true},
{Username: "melon", Password: "password2", Allow: true},
{Remote: "127.0.0.1:*", Allow: true},
{Remote: "localhost:*", Allow: true},
},
ACL: auth.ACLRules{ // ACL allows all by default
{Remote: "127.0.0.1:*"}, // local superuser allow all
{
// user melon can read and write to their own topic
Username: "melon", Filters: auth.Filters{
"melon/#": auth.ReadWrite,
"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
},
},
{
// Otherwise, no clients have publishing permissions
Filters: auth.Filters{
"#": auth.ReadOnly,
"updates/#": auth.Deny,
},
},
},
}
// you may also find this useful...
// d, _ := authRules.ToYAML()
// d, _ := authRules.ToJSON()
// fmt.Println(string(d))
server := mqtt.New(nil)
err := server.AddHook(new(auth.Hook), &auth.Options{
Ledger: authRules,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -0,0 +1,40 @@
{
"auth": [
{
"username": "peach",
"password": "password1",
"allow": true
},
{
"username": "melon",
"password": "password2",
"allow": true
},
{
"remote": "127.0.0.1:*",
"allow": false
},
{
"remote": "localhost:*",
"allow": false
}
],
"acl": [
{
"remote": "127.0.0.1:*"
},
{
"username": "melon",
"filters": {
"melon/#": 3,
"updates/#": 2
}
},
{
"filters": {
"#": 1,
"updates/#": 0
}
}
]
}

View File

@@ -0,0 +1,21 @@
auth:
- username: peach
password: password1
allow: true
- username: melon
password: password2
allow: true
# - remote: 127.0.0.1:*
# allow: true
# - remote: localhost:*
# allow: true
acl:
# 0 = deny, 1 = read only, 2 = write only, 3 = read and write
- remote: 127.0.0.1:*
- username: melon
filters:
melon/#: 3
updates/#: 2
- filters:
'#': 1
updates/#: 0

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
// You can also run from top-level server.go folder:
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.yaml
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.json
path := flag.String("path", "auth.yaml", "path to data auth file")
flag.Parse()
// Get ledger from yaml file
data, err := os.ReadFile(*path)
if err != nil {
log.Fatal(err)
}
server := mqtt.New(nil)
err = server.AddHook(new(auth.Hook), &auth.Options{
Data: data, // build ledger from byte slice, yaml or json
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -1,49 +0,0 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
stats := listeners.NewHTTPStats("stats", ":8080")
err = server.AddListener(stats, nil)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}

62
examples/debug/main.go Normal file
View File

@@ -0,0 +1,62 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/debug"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/rs/zerolog"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(debug.Hook), &debug.Options{
ShowPacketData: true,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

128
examples/hooks/main.go Normal file
View File

@@ -0,0 +1,128 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"bytes"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/mochi-co/mqtt/v2/packets"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(ExampleHook), map[string]any{})
if err != nil {
log.Fatal(err)
}
// Start the server
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
// Demonstration of directly publishing messages to a topic via the
// `server.Publish` method. Subscribe to `direct/publish` using your
// MQTT client to see the messages.
go func() {
cl := mqtt.NewInlineClient("inline", "local")
for range time.Tick(time.Second * 10) {
server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "direct/publish",
Payload: []byte("scheduled message"),
})
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}
type ExampleHook struct {
mqtt.HookBase
}
func (h *ExampleHook) ID() string {
return "events-example"
}
func (h *ExampleHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnect,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnPublished,
mqtt.OnPublish,
}, []byte{b})
}
func (h *ExampleHook) Init(config any) error {
h.Log.Info().Msg("initialised")
return nil
}
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
}
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected")
}
func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes)
}
func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed")
}
func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client")
pkx := pk
if string(pk.Payload) == "hello" {
pkx.Payload = []byte("hello world")
h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client")
}
return pkx, nil
}
func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client")
}

View File

@@ -1,16 +1,19 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"fmt"
"bytes"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/mochi-co/mqtt/v2/packets"
)
func main() {
@@ -22,41 +25,51 @@ func main() {
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("PAHO Testing Suite"))
server := mqtt.New(nil)
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
server := mqtt.New()
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(Auth),
})
_ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go server.Serve()
fmt.Println(aurora.BgMagenta(" Started! "))
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
server.Log.Info().Msg("main.go finished")
}
// Auth is an example auth provider for the server.
type Auth struct{}
type pahoAuthHook struct {
mqtt.HookBase
}
// Auth returns true if a username and password are acceptable.
// Auth always returns true.
func (a *Auth) Authenticate(user, password []byte) bool {
func (h *pahoAuthHook) ID() string {
return "allow-all-auth"
}
func (h *pahoAuthHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
return true
}
// ACL returns true if a user has access permissions to read or write on a topic.
// ACL is used to deny access to a specific topic to satisfy Test.test_subscribe_failure.
func (a *Auth) ACL(user []byte, topic string, write bool) bool {
if topic == "test/nosubscribe" {
return false
}
return true
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return topic != "test/nosubscribe"
}

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/badger"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
badgerPath := ".badger"
defer os.RemoveAll(badgerPath) // remove the example badger files at the end
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(badger.Hook), &badger.Options{
Path: badgerPath,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/bolt"
"github.com/mochi-co/mqtt/v2/listeners"
"go.etcd.io/bbolt"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(bolt.Hook), bolt.Options{
Path: "bolt.db",
Options: &bbolt.Options{
Timeout: 500 * time.Millisecond,
},
})
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -1,58 +0,0 @@
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
"go.etcd.io/bbolt"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/persistence/bolt"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("Persistence"))
server := mqtt.New()
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(Auth),
})
if err != nil {
log.Fatal(err)
}
err = server.AddStore(bolt.New("mochi-test.db", &bbolt.Options{
Timeout: 500 * time.Millisecond,
}))
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
}

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/redis"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/rs/zerolog"
rv8 "github.com/go-redis/redis/v8"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(redis.Hook), &redis.Options{
Options: &rv8.Options{
Addr: "localhost:6379", // default redis address
Password: "", // your password
DB: 0, // your redis db
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -1,16 +1,18 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
@@ -22,13 +24,22 @@ func main() {
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
// An example of configuring various server options...
options := &mqtt.Options{
// InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
}
server := mqtt.New()
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(Auth),
})
server := mqtt.New(options)
// For security reasons, the default implementation disallows all connections.
// If you want to allow all connections, you must specifically allow it.
err := server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
@@ -39,12 +50,9 @@ func main() {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
server.Log.Info().Msg("main.go finished")
}

View File

@@ -1,17 +1,19 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"fmt"
"crypto/tls"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
var (
@@ -55,51 +57,61 @@ func main() {
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
server := mqtt.New()
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
ws := listeners.NewWebsocket("ws1", ":1882")
err = server.AddListener(ws, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
// Basic TLS Config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
// Optionally, if you want clients to authenticate only with certs issued by your CA,
// you might want to use something like this:
// certPool := x509.NewCertPool()
// _ = certPool.AppendCertsFromPEM(caCertPem)
// tlsConfig := &tls.Config{
// ClientCAs: certPool,
// ClientAuth: tls.RequireAndVerifyClientCert,
// }
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", &listeners.Config{
TLSConfig: tlsConfig,
})
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
stats := listeners.NewHTTPStats("stats", ":8080")
err = server.AddListener(stats, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
ws := listeners.NewWebsocket("ws1", ":1882", &listeners.Config{
TLSConfig: tlsConfig,
})
err = server.AddListener(ws)
if err != nil {
log.Fatal(err)
}
go server.Serve()
fmt.Println(aurora.BgMagenta(" Started! "))
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{
TLSConfig: tlsConfig,
}, nil)
err = server.AddListener(stats)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
server.Log.Info().Msg("main.go finished")
}

View File

@@ -1,16 +1,18 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/logrusorgru/aurora"
mqtt "github.com/mochi-co/mqtt/server"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
@@ -22,11 +24,11 @@ func main() {
done <- true
}()
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
server := mqtt.New()
ws := listeners.NewWebsocket("ws1", ":1882")
err = server.AddListener(ws, nil)
ws := listeners.NewWebsocket("ws1", ":1882", nil)
err := server.AddListener(ws)
if err != nil {
log.Fatal(err)
}
@@ -37,11 +39,9 @@ func main() {
log.Fatal(err)
}
}()
fmt.Println(aurora.BgMagenta(" Started! "))
<-done
fmt.Println(aurora.BgRed(" Caught Signal "))
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
fmt.Println(aurora.BgGreen(" Finished "))
server.Log.Info().Msg("main.go finished")
}

101
fanpool.go Normal file
View File

@@ -0,0 +1,101 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
package mqtt
import (
"sync"
"sync/atomic"
xh "github.com/cespare/xxhash/v2"
)
// taskChan is a channel for incoming task functions.
type taskChan chan func()
// FanPool is a fixed-sized fan-style worker pool with multiple
// working 'columns'. Instead of a single queue channel processed by
// many goroutines, this fan pool uses many queue channels each
// processed by a single goroutine.
// Very special thanks are given to the authors of HMQ in particular
// @chowyu08 and @muXxer for their work on the fixpool worker pool
// https://github.com/fhmq/hmq/blob/master/pool/fixpool.go
// from which this fan-pool is heavily inspired.
type FanPool struct {
queue []taskChan
wg sync.WaitGroup
capacity uint64
perChan uint64
Mutex sync.Mutex
}
// New returns a new instance of FanPool. fanSize controls the number of 'columns'
// of the fan, whereas queueSize controls the size of each column's queue.
func NewFanPool(fanSize, queueSize uint64) *FanPool {
pool := &FanPool{
capacity: fanSize,
perChan: queueSize,
queue: make([]taskChan, fanSize),
}
pool.fillWorkers(fanSize)
return pool
}
// fillWorkers adds columns to the fan pool with an associated worker goroutine.
func (p *FanPool) fillWorkers(n uint64) {
for i := uint64(0); i < n; i++ {
p.queue[i] = make(taskChan, p.perChan)
go p.worker(p.queue[i])
p.wg.Add(1)
}
}
// worker is a worker goroutine which processes tasks from a single queue.
func (p *FanPool) worker(ch taskChan) {
defer p.wg.Done()
var task func()
var ok bool
for {
task, ok = <-ch
if !ok {
return
}
task()
}
}
// Enqueue adds a new task to the queue to be processed.
func (p *FanPool) Enqueue(id string, task func()) {
if p.Size() == 0 {
return
}
// We can use xh.Sum64 to get a specific queue index
// which remains the same for a client id, giving each
// client their own queue.
p.queue[xh.Sum64([]byte(id))%p.Size()] <- task
}
// Wait blocks until all the workers in the pool have completed.
func (p *FanPool) Wait() {
p.wg.Wait()
}
// Close issues a shutdown signal to the workers.
func (p *FanPool) Close() {
for i := 0; i < int(p.Size()); i++ {
if p.queue[i] != nil {
close(p.queue[i])
}
}
p.queue = nil
atomic.StoreUint64(&p.capacity, 0)
}
// Size returns the current number of workers in the pool.
func (p *FanPool) Size() uint64 {
return atomic.LoadUint64(&p.capacity)
}

89
fanpool_test.go Normal file
View File

@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestFanPool(t *testing.T) {
f := NewFanPool(1, 2)
require.NotNil(t, f)
require.Equal(t, uint64(1), f.capacity)
require.Equal(t, 2, cap(f.queue[0]))
o := make(chan bool)
go func() {
f.Enqueue("test", func() {
o <- true
})
}()
require.True(t, <-o)
f.Close()
f.Wait()
}
func TestFillWorkers(t *testing.T) {
f := &FanPool{
perChan: 3,
queue: make([]taskChan, 2),
}
f.fillWorkers(2)
require.Len(t, f.queue, 2)
require.Equal(t, 3, cap(f.queue[0]))
}
func TestEnqueue(t *testing.T) {
f := &FanPool{
capacity: 2,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
},
}
go func() {
f.Enqueue("a", func() {})
}()
require.NotNil(t, <-f.queue[1])
}
func TestEnqueueOnEmpty(t *testing.T) {
f := &FanPool{
queue: []taskChan{},
}
go func() {
f.Enqueue("a", func() {})
}()
require.Len(t, f.queue, 0)
}
func TestSize(t *testing.T) {
f := &FanPool{
capacity: 10,
}
require.Equal(t, uint64(10), f.Size())
}
func TestClose(t *testing.T) {
f := &FanPool{
capacity: 3,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
make(taskChan, 2),
},
}
f.Close()
require.Equal(t, uint64(0), f.Size())
require.Nil(t, f.queue)
}

46
go.mod
View File

@@ -1,18 +1,40 @@
module github.com/mochi-co/mqtt
module github.com/mochi-co/mqtt/v2
go 1.13
go 1.19
require (
github.com/alicebob/miniredis/v2 v2.23.0
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.1.0
github.com/gorilla/websocket v1.4.1
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a
github.com/krylovsk/mqtt-benchmark v0.1.1 // indirect
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23
github.com/rs/xid v1.2.1
github.com/stretchr/testify v1.4.0
go.etcd.io/bbolt v1.3.3
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
github.com/asdine/storm/v3 v3.2.1
github.com/cespare/xxhash/v2 v2.1.2
github.com/go-redis/redis/v8 v8.11.5
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
github.com/rs/xid v1.4.0
github.com/rs/zerolog v1.28.0
github.com/stretchr/testify v1.7.1
github.com/timshannon/badgerhold v1.0.0
go.etcd.io/bbolt v1.3.5
gopkg.in/yaml.v3 v3.0.1
)
replace github.com/mochi-co/debug => /Users/mochimochi/Development/Go/src/github.com/mochi-co/debug
require (
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/badger v1.6.0 // indirect
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/golang/protobuf v1.5.0 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
golang.org/x/net v0.0.0-20220927171203-f486391704dc // indirect
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.28.1 // indirect
)

137
go.sum
View File

@@ -1,66 +1,143 @@
github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M=
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM=
github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/GaryBoone/GoStats v0.0.0-20130122001700-1993eafbef57 h1:EUQH/F+mzJBs53c75r7R5zdM/kz7BHXoWBFsVXzadVw=
github.com/GaryBoone/GoStats v0.0.0-20130122001700-1993eafbef57/go.mod h1:5zDl2HgTb/k5i9op9y6IUSiuVkZFpUrWGQbZc9tNR40=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw=
github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/asdine/storm v2.1.2+incompatible h1:dczuIkyqwY2LrtXPz8ixMrU/OFgZp71kbKTHGrXYt/Q=
github.com/asdine/storm v2.1.2+incompatible/go.mod h1:RarYDc9hq1UPLImuiXK3BIWPJLdIygvV3PsInK0FbVQ=
github.com/asdine/storm/v3 v3.1.0 h1:yrpSNS+E7ef5Y5KjyZDeyW72Dl17lYG7oZ7eUoWvo5s=
github.com/asdine/storm/v3 v3.1.0/go.mod h1:letAoLCXz4UfodwNgMNILMb2oRH+su337ZfHnkRzqDA=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac=
github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/eclipse/paho.mqtt.golang v1.2.0 h1:1F8mhG9+aO5/xpdtFkW4SxOJB67ukuDC3t2y2qayIX0=
github.com/eclipse/paho.mqtt.golang v1.2.0/go.mod h1:H9keYFcgq3Qr5OUJm/JZI/i6U7joQ8SYLhZwfeOo6Ts=
github.com/dgraph-io/badger v1.6.0 h1:DshxFxZWXUcO0xX476VJC07Xsr6ZCBVRHKZ93Oh7Evo=
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a h1:zPPuIq2jAWWPTrGt70eK/BSch+gFAGrNzecsoENgu2o=
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a/go.mod h1:yL958EeXv8Ylng6IfnvG4oflryUi3vgA3xPs9hmII1s=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/krylovsk/mqtt-benchmark v0.1.1 h1:ErPkllMHrX5dAGzWBNH+yYLY/kIufGEBWOPOmOSilmg=
github.com/krylovsk/mqtt-benchmark v0.1.1/go.mod h1:ud2sw14D+GdIeJGOh9ZZnBfjAVXzPyHQl58Yagk5P9w=
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23 h1:Wp7NjqGKGN9te9N/rvXYRhlVcrulGdxnz8zadXWs7fc=
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/timshannon/badgerhold v1.0.0 h1:LtqnDRVP7294FWRiZCIfQa6Tt0bGmlzbO8c364QC2Y8=
github.com/timshannon/badgerhold v1.0.0/go.mod h1:Vv2Jj0PAfzqViEpGvJzLP8PY07x1iXLgKRuLY7bqPOE=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk=
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20191011234655-491137f69257/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20220927171203-f486391704dc h1:FxpXZdoBqT8RjqTy6i1E8nXHhW21wK7ptQ/EPIGxzPQ=
golang.org/x/net v0.0.0-20220927171203-f486391704dc/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191105142833-ac3223d80179 h1:IqVhUQp5B9ARnZUcfqXy6zP+A+YuPpP7IFo8gFeCOzU=
golang.org/x/sys v0.0.0-20191105142833-ac3223d80179/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

784
hooks.go Normal file
View File

@@ -0,0 +1,784 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
)
const (
SetOptions byte = iota
OnSysInfoTick
OnStarted
OnStopped
OnConnectAuthenticate
OnACLCheck
OnConnect
OnSessionEstablished
OnDisconnect
OnAuthPacket
OnPacketRead
OnPacketEncode
OnPacketSent
OnPacketProcessed
OnSubscribe
OnSubscribed
OnSelectSubscribers
OnUnsubscribe
OnUnsubscribed
OnPublish
OnPublished
OnRetainMessage
OnQosPublish
OnQosComplete
OnQosDropped
OnWill
OnWillSent
OnClientExpired
OnRetainedExpired
OnExpireInflights
StoredClients
StoredSubscriptions
StoredInflightMessages
StoredRetainedMessages
StoredSysInfo
)
var (
// ErrInvalidConfigType indicates a different Type of config value was expected to what was received.
ErrInvalidConfigType = errors.New("invalid config type provided")
)
// Hook provides an interface of handlers for different events which occur
// during the lifecycle of the broker.
type Hook interface {
ID() string
Provides(b byte) bool
Init(config any) error
Stop() error
SetOpts(l *zerolog.Logger, o *HookOptions)
OnStarted()
OnStopped()
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
OnACLCheck(cl *Client, topic string, write bool) bool
OnSysInfoTick(*system.Info)
OnConnect(cl *Client, pk packets.Packet)
OnSessionEstablished(cl *Client, pk packets.Packet)
OnDisconnect(cl *Client, err error, expire bool)
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation
OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client
OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client
OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled)
OnSubscribe(cl *Client, pk packets.Packet) packets.Packet
OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers
OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
OnQosDropped(cl *Client, pk packets.Packet)
OnWill(cl *Client, will Will) (Will, error)
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
OnExpireInflights(cl *Client, expiry int64)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
StoredRetainedMessages() ([]storage.Message, error)
StoredSysInfo() (storage.SystemInfo, error)
}
// HookOptions contains values which are inherited from the server on initialisation.
type HookOptions struct {
Capabilities *Capabilities
}
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal []Hook // a slice of hooks
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex
}
// Len returns the number of hooks added.
func (h *Hooks) Len() int64 {
return atomic.LoadInt64(&h.qty)
}
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.internal {
for _, hb := range b {
if hook.Provides(hb) {
return true
}
}
}
return false
}
// Add adds and initializes a new hook.
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
if h.internal == nil {
h.internal = []Hook{}
}
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
h.internal = append(h.internal, hook)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.internal {
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
if err := hook.Stop(); err != nil {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
}
h.wg.Done()
}
}()
h.wg.Wait()
}
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.internal {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
}
}
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.internal {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
}
}
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.internal {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
}
}
// OnConnect is called when a new client connects.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
if hook.Provides(OnConnect) {
hook.OnConnect(cl, pk)
}
}
}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
}
}
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.internal {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
}
}
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected")
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
return
}
// OnAuthPacket is called when an auth packet is received. It is intended to allow developers
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
return pk, err
}
pkx = npk
}
}
return
}
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
}
return pk
}
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.internal {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
}
}
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.internal {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
}
}
// OnSubscribe is called when a client subscribes to one or more filters. This method
// differs from OnSubscribed in that it allows you to modify the subscription values
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
}
return pk
}
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.internal {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
}
}
// OnSelectSubscribers is called when subscribers have been collected for a topic, but before
// shared subscription subscribers have been selected. This hook can be used to programmatically
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.internal {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
}
return subs
}
// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method
// differs from OnUnsubscribed in that it allows you to modify the unsubscription values
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
}
return pk
}
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
}
}
// OnPublish is called when a client publishes a message. This method differs from OnMessage
// in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
return
}
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.internal {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
}
}
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.internal {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
}
}
// OnQosComplete is called when the Qos flow for a message has been completed.
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
}
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
}
}
// OnWill is called when a client disconnects and publishes an LWT message. This method
// differs from OnWillSent in that it allows you to modify the LWT message before it is
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.internal {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error")
continue
}
will = mlwt
}
}
return will
}
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
}
}
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.internal {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
}
}
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.internal {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
}
}
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.internal {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.internal {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.internal {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info")
return v, err
}
if v.Version != "" {
return v, nil
}
}
}
return
}
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
// An implementation of this method MUST be used to allow or deny access to the
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.internal {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
}
}
}
return false
}
// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter.
// An implementation of this method MUST be used to allow or deny access to the
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.internal {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
}
}
}
return false
}
// OnExpireInflights is called when the server issues a clear request for expired
// inflight messages. Expiry should be the time after which the message is no longer
// valid (usually some time in the past). A message has expired if it's created time
// is older than time.Now() minus Inflight TTL. This method can be used to expire
// old inflight messages in a persistent store which doesnt support per-item TTL.
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
for _, hook := range h.internal {
if hook.Provides(OnExpireInflights) {
hook.OnExpireInflights(cl, expiry)
}
}
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
Hook
Log *zerolog.Logger
Opts *HookOptions
}
// ID returns the ID of the hook.
func (h *HookBase) ID() string {
return "base"
}
// Provides indicates which methods a hook provides. The default is none - this method
// should be overridden by the embedding hook.
func (h *HookBase) Provides(b byte) bool {
return false
}
// Init performs any pre-start initializations for the hook, such as connecting to databases
// or opening files.
func (h *HookBase) Init(config any) error {
return nil
}
// SetOpts is called by the server to propagate internal values and generally should
// not be called manually.
func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
h.Log = l
h.Opts = opts
}
// Stop is called to gracefully shutdown the hook.
func (h *HookBase) Stop() error {
return nil
}
// OnStarted is called when the server starts.
func (h *HookBase) OnStarted() {}
// OnStopped is called when the server stops.
func (h *HookBase) OnStopped() {}
// OnSysInfoTick is called when the server publishes system info.
func (h *HookBase) OnSysInfoTick(*system.Info) {}
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return false
}
// OnACLCheck is called when a user attempts to subscribe or publish to a topic.
func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnConnect is called when a new client connects.
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
// OnDisconnect is called when a client is disconnected for any reason.
func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {}
// OnAuthPacket is called when an auth packet is received from the client.
func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPacketRead is called when a packet is received.
func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPacketEncode is called before a packet is byte-encoded and written to the client.
func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnPacketSent is called immediately after a packet is written to a client.
func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {}
// OnPacketProcessed is called immediately after a packet from a client is processed.
func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {}
// OnSubscribe is called when a client subscribes to one or more filters.
func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {}
// OnSelectSubscribers is called when selecting subscribers to receive a message.
func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
return subs
}
// OnUnsubscribe is called when a client unsubscribes from one or more filters.
func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {}
// OnPublish is called when a client publishes a message.
func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
// OnQosDropped is called the Qos flow for a message expires.
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
// OnWill is called when a client disconnects and publishes an LWT message.
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
return will, nil
}
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {}
// OnClientExpired is called when a client session has expired.
func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return
}
// StoredSubscriptions returns all subcriptions from a store.
func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
return
}
// StoredInflightMessages returns all inflight messages from a store.
func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) {
return
}
// StoredRetainedMessages returns all retained messages from a store.
func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
return
}
// StoredSysInfo returns a set of system info values.
func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
return
}

41
hooks/auth/allow_all.go Normal file
View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
)
// AllowHook is an authentication hook which allows connection access
// for all users and read and write access to all topics.
type AllowHook struct {
mqtt.HookBase
}
// ID returns the ID of the hook.
func (h *AllowHook) ID() string {
return "allow-all-auth"
}
// Provides indicates which hook methods this hook provides.
func (h *AllowHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
// OnConnectAuthenticate returns true/allowed for all requests.
func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
return true
}
// OnACLCheck returns true/allowed for all checks.
func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return true
}

View File

@@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
func TestAllowAllID(t *testing.T) {
h := new(AllowHook)
require.Equal(t, "allow-all-auth", h.ID())
}
func TestAllowAllProvides(t *testing.T) {
h := new(AllowHook)
require.True(t, h.Provides(mqtt.OnACLCheck))
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
require.False(t, h.Provides(mqtt.OnPublished))
}
func TestAllowAllOnConnectAuthenticate(t *testing.T) {
h := new(AllowHook)
require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{}))
}
func TestAllowAllOnACLCheck(t *testing.T) {
h := new(AllowHook)
require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true))
}

107
hooks/auth/auth.go Normal file
View File

@@ -0,0 +1,107 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
)
// Options contains the configuration/rules data for the auth ledger.
type Options struct {
Data []byte
Ledger *Ledger
}
// Hook is an authentication hook which implements an auth ledger.
type Hook struct {
mqtt.HookBase
config *Options
ledger *Ledger
}
// ID returns the ID of the hook.
func (h *Hook) ID() string {
return "auth-ledger"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
// Init configures the hook with the auth ledger to be used for checking.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
var err error
if h.config.Ledger != nil {
h.ledger = h.config.Ledger
} else if len(h.config.Data) > 0 {
h.ledger = new(Ledger)
err = h.ledger.Unmarshal(h.config.Data)
}
if err != nil {
return err
}
if h.ledger == nil {
h.ledger = &Ledger{
Auth: AuthRules{},
ACL: ACLRules{},
}
}
h.Log.Info().
Int("authentication", len(h.ledger.Auth)).
Int("acl", len(h.ledger.ACL)).
Msg("loaded auth rules")
return nil
}
// OnConnectAuthenticate returns true if the connecting client has rules which provide access
// in the auth ledger.
func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
if _, ok := h.ledger.AuthOk(cl, pk); ok {
return true
}
h.Log.Info().
Str("username", string(pk.Connect.Username)).
Str("remote", cl.Net.Remote).
Msg("client failed authentication check")
return false
}
// OnACLCheck returns true if the connecting client has matching read or write access to subscribe
// or publish to a given topic.
func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
if _, ok := h.ledger.ACLOk(cl, topic, write); ok {
return true
}
h.Log.Debug().
Str("client", cl.ID).
Str("username", string(cl.Properties.Username)).
Str("topic", topic).
Msg("client failed allowed ACL check")
return false
}

213
hooks/auth/auth_test.go Normal file
View File

@@ -0,0 +1,213 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"os"
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
// func teardown(t *testing.T, path string, h *Hook) {
// h.Stop()
// }
func TestBasicID(t *testing.T) {
h := new(Hook)
require.Equal(t, "auth-ledger", h.ID())
}
func TestBasicProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnACLCheck))
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
require.False(t, h.Provides(mqtt.OnPublish))
}
func TestBasicInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestBasicInitDefaultConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
}
func TestBasicInitWithLedgerPointer(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := &Ledger{
Auth: []AuthRule{
{
Remote: "127.0.0.1",
Allow: true,
},
},
ACL: []ACLRule{
{
Remote: "127.0.0.1",
Filters: Filters{
"#": ReadWrite,
},
},
},
}
err := h.Init(&Options{
Ledger: ln,
})
require.NoError(t, err)
require.Same(t, ln, h.ledger)
}
func TestBasicInitWithLedgerJSON(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: ledgerJSON,
})
require.NoError(t, err)
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
}
func TestBasicInitWithLedgerYAML(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: ledgerYAML,
})
require.NoError(t, err)
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
}
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: []byte("fdsfdsafasd"),
})
require.Error(t, err)
}
func TestOnConnectAuthenticate(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
ln.ACL = checkLedger.ACL
err := h.Init(
&Options{
Ledger: ln,
},
)
require.NoError(t, err)
require.True(t, h.OnConnectAuthenticate(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
))
require.False(t, h.OnConnectAuthenticate(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
))
require.False(t, h.OnConnectAuthenticate(
&mqtt.Client{},
packets.Packet{},
))
}
func TestOnACL(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
ln.ACL = checkLedger.ACL
err := h.Init(
&Options{
Ledger: ln,
},
)
require.NoError(t, err)
require.True(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"mochi/info",
true,
))
require.False(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"d/j/f",
true,
))
require.True(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"readonly",
false,
))
require.False(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"readonly",
true,
))
}

231
hooks/auth/ledger.go Normal file
View File

@@ -0,0 +1,231 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"encoding/json"
"strings"
"sync"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"gopkg.in/yaml.v3"
)
const (
Deny Access = iota // user cannot access the topic
ReadOnly // user can only subscribe to the topic
WriteOnly // user can only publish to the topic
ReadWrite // user can both publish and subscribe to the topic
)
// Access determines the read/write privileges for an ACL rule.
type Access byte
// Users contains a map of access rules for specific users, keyed on username.
type Users map[string]UserRule
// UserRule defines a set of access rules for a specific user.
type UserRule struct {
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired
Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user
}
// AuthRules defines generic access rules applicable to all users.
type AuthRules []AuthRule
type AuthRule struct {
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users
}
// ACLRules defines generic topic or filter access rules applicable to all users.
type ACLRules []ACLRule
// ACLRule defines access rules for a specific topic or filter.
type ACLRule struct {
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match
}
// Filters is a map of Access rules keyed on filter.
type Filters map[RString]Access
// RString is a rule value string.
type RString string
// Matches returns true if the rule matches a given string.
func (r RString) Matches(a string) bool {
rr := string(r)
if r == "" || r == "*" || a == rr {
return true
}
i := strings.Index(rr, "*")
if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 {
return true
}
return false
}
// FilterMatches returns true if a filter matches a topic rule.
func (f RString) FilterMatches(a string) bool {
_, ok := MatchTopic(string(f), a)
return ok
}
// MatchTopic checks if a given topic matches a filter, accounting for filter
// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c.
func MatchTopic(filter string, topic string) (elements []string, matched bool) {
filterParts := strings.Split(filter, "/")
topicParts := strings.Split(topic, "/")
elements = make([]string, 0)
for i := 0; i < len(filterParts); i++ {
if i >= len(topicParts) {
matched = false
return
}
if filterParts[i] == "+" {
elements = append(elements, topicParts[i])
continue
}
if filterParts[i] == "#" {
matched = true
elements = append(elements, strings.Join(topicParts[i:], "/"))
return
}
if filterParts[i] != topicParts[i] {
matched = false
return
}
}
return elements, true
}
// Ledger is an auth ledger containing access rules for users and topics.
type Ledger struct {
sync.Mutex `json:"-" yaml:"-"`
Users Users `json:"users" yaml:"users"`
Auth AuthRules `json:"auth" yaml:"auth"`
ACL ACLRules `json:"acl" yaml:"acl"`
}
// Update updates the internal values of the ledger.
func (l *Ledger) Update(ln *Ledger) {
l.Lock()
defer l.Unlock()
l.Auth = ln.Auth
l.ACL = ln.ACL
}
// AuthOk returns true if the rules indicate the user is allowed to authenticate.
func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
if l.Users != nil {
if u, ok := l.Users[string(cl.Properties.Username)]; ok &&
u.Password != "" &&
u.Password == RString(pk.Connect.Password) {
return 0, !u.Disallow
}
}
// If there's no users map, or no user was found, attempt to find a matching
// rule (which may also contain a user).
for n, rule := range l.Auth {
if rule.Client.Matches(cl.ID) &&
rule.Username.Matches(string(cl.Properties.Username)) &&
rule.Password.Matches(string(pk.Connect.Password)) &&
rule.Remote.Matches(cl.Net.Remote) {
return n, rule.Allow
}
}
return 0, false
}
// ACLOk returns true if the rules indicate the user is allowed to read or write to
// a specific filter or topic respectively, based on the write bool.
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
if l.Users != nil {
if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 {
for filter, access := range u.ACL {
if filter.FilterMatches(topic) {
if !write && (access == ReadOnly || access == ReadWrite) {
return n, true
} else if write && (access == WriteOnly || access == ReadWrite) {
return n, true
} else {
return n, false
}
}
}
}
}
for n, rule := range l.ACL {
if rule.Client.Matches(cl.ID) &&
rule.Username.Matches(string(cl.Properties.Username)) &&
rule.Remote.Matches(cl.Net.Remote) {
if len(rule.Filters) == 0 {
return n, true
}
for filter, access := range rule.Filters {
if filter.FilterMatches(topic) {
if !write && (access == ReadOnly || access == ReadWrite) {
return n, true
} else if write && (access == WriteOnly || access == ReadWrite) {
return n, true
} else {
return n, false
}
}
}
}
}
return 0, true
}
// ToJSON encodes the values into a JSON string.
func (l *Ledger) ToJSON() (data []byte, err error) {
return json.Marshal(l)
}
// ToYAML encodes the values into a YAML string.
func (l *Ledger) ToYAML() (data []byte, err error) {
return yaml.Marshal(l)
}
// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct.
func (l *Ledger) Unmarshal(data []byte) error {
l.Lock()
defer l.Unlock()
if len(data) == 0 {
return nil
}
if data[0] == '{' {
return json.Unmarshal(data, l)
}
return yaml.Unmarshal(data, &l)
}

610
hooks/auth/ledger_test.go Normal file
View File

@@ -0,0 +1,610 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
var (
checkLedger = Ledger{
Users: Users{ // users are allowed by default
"mochi-co": {
Password: "melon",
ACL: Filters{
"d/+/f": Deny,
"mochi-co/#": ReadWrite,
"readonly": ReadOnly,
},
},
"suspended-username": {
Password: "any",
Disallow: true,
},
"mochi": { // ACL only, will defer to AuthRules for authentication
ACL: Filters{
"special/mochi": ReadOnly,
"secret/mochi": Deny,
"ignored": ReadWrite,
},
},
},
Auth: AuthRules{
{Username: "banned-user"}, // never allow specific username
{Remote: "127.0.0.1", Allow: true}, // always allow localhost
{Remote: "123.123.123.123"}, // disallow any from specific address
{Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address
{Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username)
{Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass
{Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map)
},
ACL: ACLRules{
{
Username: "mochi", // allow matching user/pass
Filters: Filters{
"a/b/c": Deny,
"d/+/f": Deny,
"mochi/#": ReadWrite,
"updates/#": WriteOnly,
"readonly": ReadOnly,
"ignored": Deny,
},
},
{Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost
{Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin
{Remote: "001.002.003.004"}, // Allow all with no filter
{Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others
},
}
)
func TestRStringMatches(t *testing.T) {
require.True(t, RString("*").Matches("any"))
require.True(t, RString("*").Matches(""))
require.True(t, RString("").Matches("any"))
require.True(t, RString("").Matches(""))
require.False(t, RString("no").Matches("any"))
require.False(t, RString("no").Matches(""))
}
func TestCanAuthenticate(t *testing.T) {
tt := []struct {
desc string
client *mqtt.Client
pk packets.Packet
n int
ok bool
}{
{
desc: "allow all local 127.0.0.1",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{Connect: packets.ConnectParams{}},
ok: true,
n: 1,
},
{
desc: "allow username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 5,
},
{
desc: "deny username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: false,
n: 0,
},
{
desc: "allow all local 127.0.0.1",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: true,
n: 1,
},
{
desc: "allow username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 5,
},
{
desc: "deny username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: false,
n: 0,
},
{
desc: "deny client from address",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("not-mochi"),
},
Net: mqtt.ClientConnection{
Remote: "111.144.155.166",
},
},
pk: packets.Packet{},
ok: false,
n: 3,
},
{
desc: "allow remote wildcard",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "111.0.0.1",
},
},
pk: packets.Packet{},
ok: true,
n: 4,
},
{
desc: "never allow username",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("banned-user"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{},
ok: false,
n: 0,
},
{
desc: "matching user in users",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi-co"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 0,
},
{
desc: "never user in users",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("suspended-user"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}},
ok: false,
n: 0,
},
}
for _, d := range tt {
t.Run(d.desc, func(t *testing.T) {
n, ok := checkLedger.AuthOk(d.client, d.pk)
require.Equal(t, d.n, n)
require.Equal(t, d.ok, ok)
})
}
}
func TestCanACL(t *testing.T) {
tt := []struct {
client *mqtt.Client
desc string
topic string
n int
write bool
ok bool
}{
{
desc: "allow normal write on any other filter",
client: &mqtt.Client{},
topic: "default/acl/write/access",
write: true,
ok: true,
},
{
desc: "allow normal read on any other filter",
client: &mqtt.Client{},
topic: "default/acl/read/access",
write: false,
ok: true,
},
{
desc: "deny user on literal filter",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "a/b/c",
},
{
desc: "deny user on partial filter",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "d/j/f",
},
{
desc: "allow read/write to user path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "mochi/read/write",
write: true,
ok: true,
},
{
desc: "deny read on write-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/no/reading",
write: false,
ok: false,
},
{
desc: "deny read on write-only path ext",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/mochi",
write: false,
ok: false,
},
{
desc: "allow read on not-acl path (no #)",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates",
write: false,
ok: true,
},
{
desc: "allow write on write-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/mochi",
write: true,
ok: true,
},
{
desc: "deny write on read-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "readonly",
write: true,
ok: false,
},
{
desc: "allow read on read-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "readonly",
write: false,
ok: true,
},
{
desc: "allow $sys access to localhost",
client: &mqtt.Client{
Net: mqtt.ClientConnection{
Remote: "localhost",
},
},
topic: "$SYS/test",
write: false,
ok: true,
n: 1,
},
{
desc: "allow $sys access to admin",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("admin"),
},
},
topic: "$SYS/test",
write: false,
ok: true,
n: 2,
},
{
desc: "deny $sys access to all others",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "$SYS/test",
write: false,
ok: false,
n: 4,
},
{
desc: "allow all with no filter",
client: &mqtt.Client{
Net: mqtt.ClientConnection{
Remote: "001.002.003.004",
},
},
topic: "any/path",
write: true,
ok: true,
n: 3,
},
{
desc: "use users embedded acl deny",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "secret/mochi",
write: true,
ok: false,
},
{
desc: "use users embedded acl any",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "any/mochi",
write: true,
ok: true,
},
{
desc: "use users embedded acl write on read-only",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "special/mochi",
write: true,
ok: false,
},
{
desc: "use users embedded acl read on read-only",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "special/mochi",
write: false,
ok: true,
},
{
desc: "preference users embedded acl",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "ignored",
write: true,
ok: true,
},
}
for _, d := range tt {
t.Run(d.desc, func(t *testing.T) {
n, ok := checkLedger.ACLOk(d.client, d.topic, d.write)
require.Equal(t, d.n, n)
require.Equal(t, d.ok, ok)
})
}
}
func TestMatchTopic(t *testing.T) {
el, matched := MatchTopic("a/+/c/+", "a/b/c/d")
require.True(t, matched)
require.Equal(t, []string{"b", "d"}, el)
el, matched = MatchTopic("a/+/+/+", "a/b/c/d")
require.True(t, matched)
require.Equal(t, []string{"b", "c", "d"}, el)
el, matched = MatchTopic("stuff/#", "stuff/things/yeah")
require.True(t, matched)
require.Equal(t, []string{"things/yeah"}, el)
el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds")
require.True(t, matched)
require.Equal(t, []string{"b", "c/d/as/dds"}, el)
el, matched = MatchTopic("test", "test")
require.True(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic("things/stuff//", "things/stuff/")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic("t", "t2")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic(" ", " ")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
}
var (
ledgerStruct = Ledger{
Users: Users{
"mochi": {
Password: "peach",
ACL: Filters{
"readonly": ReadOnly,
"deny": Deny,
},
},
},
Auth: AuthRules{
{
Client: "*",
Username: "mochi-co",
Password: "melon",
Remote: "192.168.1.*",
Allow: true,
},
},
ACL: ACLRules{
{
Client: "*",
Username: "mochi-co",
Remote: "127.*",
Filters: Filters{
"readonly": ReadOnly,
"writeonly": WriteOnly,
"readwrite": ReadWrite,
"deny": Deny,
},
},
},
}
ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`)
ledgerYAML = []byte(`users:
mochi:
password: peach
acl:
deny: 0
readonly: 1
auth:
- client: '*'
username: mochi-co
remote: 192.168.1.*
password: melon
allow: true
acl:
- client: '*'
username: mochi-co
remote: 127.*
filters:
deny: 0
readonly: 1
readwrite: 3
writeonly: 2
`)
)
func TestLedgerUpdate(t *testing.T) {
old := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
},
}
new := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
{Remote: "192.168.*", Allow: true},
},
}
old.Update(new)
require.Len(t, old.Auth, 2)
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
require.NotSame(t, new, old)
}
func TestLedgerToJSON(t *testing.T) {
data, err := ledgerStruct.ToJSON()
require.NoError(t, err)
require.Equal(t, ledgerJSON, data)
}
func TestLedgerToYAML(t *testing.T) {
data, err := ledgerStruct.ToYAML()
require.NoError(t, err)
require.Equal(t, ledgerYAML, data)
}
func TestLedgerUnmarshalFromYAML(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal(ledgerYAML)
require.NoError(t, err)
require.Equal(t, &ledgerStruct, l)
require.NotSame(t, l, &ledgerStruct)
}
func TestLedgerUnmarshalFromJSON(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal(ledgerJSON)
require.NoError(t, err)
require.Equal(t, &ledgerStruct, l)
require.NotSame(t, l, &ledgerStruct)
}
func TestLedgerUnmarshalNil(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal([]byte{})
require.NoError(t, err)
require.Equal(t, new(Ledger), l)
}

250
hooks/debug/debug.go Normal file
View File

@@ -0,0 +1,250 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package debug
import (
"strings"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/rs/zerolog"
)
// Options contains configuration settings for the debug output.
type Options struct {
ShowPacketData bool // include decoded packet data (default false)
ShowPings bool // show ping requests and responses (default false)
ShowPasswords bool // show connecting user passwords (default false)
}
// Hook is a debugging hook which logs additional low-level information from the server.
type Hook struct {
mqtt.HookBase
config *Options
Log *zerolog.Logger
}
// ID returns the ID of the hook.
func (h *Hook) ID() string {
return "debug"
}
// Provides indicates that this hook provides all methods.
func (h *Hook) Provides(b byte) bool {
return true
}
// Init is called when the hook is initialized.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
return nil
}
// SetOpts is called when the hook receives inheritable server parameters.
func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) {
h.Log = l
h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send()
}
// Stop is called when the hook is stopped.
func (h *Hook) Stop() error {
h.Log.Debug().Str("method", "Stop").Send()
return nil
}
// OnStarted is called when the server starts.
func (h *Hook) OnStarted() {
h.Log.Debug().Str("method", "OnStarted").Send()
}
// OnStopped is called when the server stops.
func (h *Hook) OnStopped() {
h.Log.Debug().Str("method", "OnStopped").Send()
}
// OnPacketRead is called when a new packet is received from a client.
func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
return pk, nil
}
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
return pk, nil
}
// OnPacketSent is called when a packet is sent to a client.
func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
return
}
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
}
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic")
}
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out")
}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete")
}
// OnQosDropped is called the Qos flow for a message expires.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped")
}
// OnLWTSent is called when a will message has been issued from a disconnecting client.
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client")
}
// OnRetainedExpired is called when the server clears expired retained messages.
func (h *Hook) OnRetainedExpired(filter string) {
h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired")
}
// OnClientExpired is called when the server clears an expired client.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired")
}
// StoredClients is called when the server restores clients from a store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
h.Log.Debug().
Str("method", "StoredClients").
Send()
return v, nil
}
// StoredClients is called when the server restores subscriptions from a store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
h.Log.Debug().
Str("method", "StoredSubscriptions").
Send()
return v, nil
}
// StoredClients is called when the server restores retained messages from a store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
h.Log.Debug().
Str("method", "StoredRetainedMessages").
Send()
return v, nil
}
// StoredClients is called when the server restores inflight messages from a store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
h.Log.Debug().
Str("method", "StoredInflightMessages").
Send()
return v, nil
}
// StoredClients is called when the server restores system info from a store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
h.Log.Debug().
Str("method", "StoredClients").
Send()
return v, nil
}
// packetMeta adds additional type-specific metadata to the debug logs.
func (h *Hook) packetMeta(pk packets.Packet) map[string]any {
m := map[string]any{}
switch pk.FixedHeader.Type {
case packets.Connect:
m["id"] = pk.Connect.ClientIdentifier
m["clean"] = pk.Connect.Clean
m["keepalive"] = pk.Connect.Keepalive
m["version"] = pk.ProtocolVersion
m["username"] = string(pk.Connect.Username)
if h.config.ShowPasswords {
m["password"] = string(pk.Connect.Password)
}
if pk.Connect.WillFlag {
m["will_topic"] = pk.Connect.WillTopic
m["will_payload"] = string(pk.Connect.WillPayload)
}
case packets.Publish:
m["topic"] = pk.TopicName
m["payload"] = string(pk.Payload)
m["raw"] = pk.Payload
m["qos"] = pk.FixedHeader.Qos
m["id"] = pk.PacketID
case packets.Connack:
fallthrough
case packets.Disconnect:
fallthrough
case packets.Puback:
fallthrough
case packets.Pubrec:
fallthrough
case packets.Pubrel:
fallthrough
case packets.Pubcomp:
m["id"] = pk.PacketID
m["reason"] = int(pk.ReasonCode)
if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 {
m["reason_string"] = pk.Properties.ReasonString
}
case packets.Subscribe:
f := map[string]int{}
ids := map[string]int{}
for _, v := range pk.Filters {
f[v.Filter] = int(v.Qos)
ids[v.Filter] = v.Identifier
}
m["filters"] = f
m["subids"] = f
case packets.Unsubscribe:
f := []string{}
for _, v := range pk.Filters {
f = append(f, v.Filter)
}
m["filters"] = f
case packets.Suback:
fallthrough
case packets.Unsuback:
r := []int{}
for _, v := range pk.ReasonCodes {
r = append(r, int(v))
}
m["reasons"] = r
case packets.Auth:
// tbd
}
if h.config.ShowPacketData {
m["packet"] = pk
}
return m
}

View File

@@ -0,0 +1,484 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"bytes"
"errors"
"strings"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/timshannon/badgerhold"
)
const (
// defaultDbFile is the default file path for the badger db file.
defaultDbFile = ".badger"
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the BadgerDB instance.
type Options struct {
Options *badgerhold.Options
Path string
}
// Hook is a persistent storage hook based using BadgerDB file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the BadgerDB instance.
db *badgerhold.Store // the BadgerDB instance.
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "badger-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// Init initializes and connects to the badger instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
if h.config.Path == "" {
h.config.Path = defaultDbFile
}
options := badgerhold.DefaultOptions
options.Dir = h.config.Path
options.ValueDir = h.config.Path
options.Logger = h
var err error
h.db, err = badgerhold.Open(options)
if err != nil {
return err
}
return nil
}
// Stop closes the badger instance.
func (h *Hook) Stop() error {
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data")
}
}
// OnDisconnect removes a client from the store if their session has expired.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
h.updateClient(cl)
if !expire {
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription))
if err != nil {
h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
PacketID: pk.PacketID,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(inflightKey(cl, pk), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data")
}
}
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.Delete(m.ID, new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.Delete(retainedKey(filter), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.ClientKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.SubscriptionKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.RetainedKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Get(storage.SysInfoKey, &v)
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// Errorf satisfies the badger interface for an error logger.
func (h *Hook) Errorf(m string, v ...interface{}) {
h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Warningf satisfies the badger interface for a warning logger.
func (h *Hook) Warningf(m string, v ...interface{}) {
h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Infof satisfies the badger interface for an info logger.
func (h *Hook) Infof(m string, v ...interface{}) {
h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Debugf satisfies the badger interface for a debug logger.
func (h *Hook) Debugf(m string, v ...interface{}) {
h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}

View File

@@ -0,0 +1,695 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"errors"
"os"
"strings"
"testing"
"time"
"github.com/asdine/storm/v3"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"github.com/timshannon/badgerhold"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
h.db.Badger().Close()
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
require.NoError(t, err)
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "badger-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.db.Get(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.db.Get(clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.db.Get(clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, badgerhold.ErrNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.db.Upsert(clientKey, &storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.db.Get(clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.db.Get(clientKey, r)
require.Error(t, err)
require.ErrorIs(t, badgerhold.ErrNotFound, err)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.db.Get(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, badgerhold.ErrNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.db.Upsert(m.ID, m)
require.NoError(t, err)
r := new(storage.Message)
err = h.db.Get(m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.db.Get(m.ID, r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.db.Get(inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.db.Get(inflightKey(client, pk), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.db.Get(storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.db.Upsert("cl1", &storage.Client{ID: "cl1", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Upsert("cl2", &storage.Client{ID: "cl2", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Upsert("cl3", &storage.Client{ID: "cl3", T: storage.ClientKey})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.db.Upsert("sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Upsert("sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Upsert("sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("m2", &storage.Message{ID: "m2", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("m3", &storage.Message{ID: "m3", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert(storage.SysInfoKey, &storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestErrorf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Errorf("test", 1, 2, 3)
}
func TestWarningf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Warningf("test", 1, 2, 3)
}
func TestInfof(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Infof("test", 1, 2, 3)
}
func TestDebugf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Debugf("test", 1, 2, 3)
}

486
hooks/storage/bolt/bolt.go Normal file
View File

@@ -0,0 +1,486 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
package bolt
import (
"bytes"
"errors"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
sgob "github.com/asdine/storm/codec/gob"
"github.com/asdine/storm/v3"
"go.etcd.io/bbolt"
)
const (
// defaultDbFile is the default file path for the boltdb file.
defaultDbFile = "bolt.db"
// defaultTimeout is the default time to hold a connection to the file.
defaultTimeout = 250 * time.Millisecond
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the bolt instance.
type Options struct {
Options *bbolt.Options
Path string
}
// Hook is a persistent storage hook based using boltdb file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the boltdb instance.
db *storm.DB // the boltdb instance.
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "bolt-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// Init initializes and connects to the boltdb instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
if h.config.Options == nil {
h.config.Options = &bbolt.Options{
Timeout: defaultTimeout,
}
}
if h.config.Path == "" {
h.config.Path = defaultDbFile
}
var err error
h.db, err = storm.Open(h.config.Path, storm.BoltOptions(0600, h.config.Options), storm.Codec(sgob.Codec))
if err != nil {
return err
}
return nil
}
// Stop closes the boltdb instance.
func (h *Hook) Stop() error {
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data")
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if !expire {
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.DeleteStruct(&storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", subscriptionKey(cl, pk.Filters[i].Filter)).
Msg("failed to delete client")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.DeleteStruct(&storage.Message{
ID: retainedKey(pk.TopicName),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", retainedKey(pk.TopicName)).
Msg("failed to delete retained publish")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save retained publish data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save qos inflight data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Message{
ID: inflightKey(cl, pk),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", inflightKey(cl, pk)).
Msg("failed to delete inflight data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Interface("data", in).
Msg("failed to save $SYS data")
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
return
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.ClientKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.SubscriptionKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.RetainedKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.One("ID", storage.SysInfoKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}

View File

@@ -0,0 +1,730 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package bolt
import (
"errors"
"os"
"testing"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/asdine/storm/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
err := os.Remove(path)
require.NoError(t, err)
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "bolt-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultTimeout, h.config.Options.Timeout)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestInitBadPath(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Path: "..",
})
require.Error(t, err)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.db.One("ID", clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.db.One("ID", clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.db.One("ID", clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, storm.ErrNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.db.One("ID", clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.db.Save(&storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.db.One("ID", clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.db.One("ID", clientKey, r)
require.Error(t, err)
require.ErrorIs(t, storm.ErrNotFound, err)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.db.Save(m)
require.NoError(t, err)
r := new(storage.Message)
err = h.db.One("ID", m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.db.One("ID", m.ID, r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.db.One("ID", inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved to bolt
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.db.One("ID", inflightKey(client, pk), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.db.One("ID", storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.db.Save(&storage.Client{ID: "cl1", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Save(&storage.Client{ID: "cl2", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Save(&storage.Client{ID: "cl3", T: storage.ClientKey})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredClientsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredClients()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.db.Save(&storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Save(&storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Save(&storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptionsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m2", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m3", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with sys info
err = h.db.Save(&storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfoClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.Error(t, err)
}

View File

@@ -0,0 +1,529 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
redis "github.com/go-redis/redis/v8"
)
// defaultAddr is the default address to the redis service.
const defaultAddr = "localhost:6379"
// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt.
const defaultHPrefix = "mochi-"
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the bolt instance.
type Options struct {
HPrefix string
Options *redis.Options
}
// Hook is a persistent storage hook based using Redis as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for connecting to the Redis instance.
db *redis.Client // the Redis instance
ctx context.Context // a context for the connection
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "redis-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnWillSent,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// hKey returns a hash set key with a unique prefix.
func (h *Hook) hKey(s string) string {
return h.config.HPrefix + s
}
// Init initializes and connects to the redis service.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
h.ctx = context.Background()
if config == nil {
config = &Options{
Options: &redis.Options{
Addr: defaultAddr,
},
}
}
h.config = config.(*Options)
if h.config.HPrefix == "" {
h.config.HPrefix = defaultHPrefix
}
h.Log.Info().
Str("address", h.config.Options.Addr).
Str("username", h.config.Options.Username).
Int("password-len", len(h.config.Options.Password)).
Int("db", h.config.Options.DB).
Msg("connecting to redis service")
h.db = redis.NewClient(h.config.Options)
_, err := h.db.Ping(context.Background()).Result()
if err != nil {
return fmt.Errorf("failed to ping service: %w", err)
}
h.Log.Info().Msg("connected to redis service")
return nil
}
// Close closes the redis connection.
func (h *Hook) Stop() error {
h.Log.Info().Msg("disconnecting from redis service")
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data")
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if !expire {
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
}
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data")
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
if d.Created < expiry || d.Created == 0 {
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll client data")
return
}
for _, row := range rows {
var d storage.Client
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data")
}
v = append(v, d)
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll subscription data")
return
}
for _, row := range rows {
var d storage.Subscription
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data")
}
v = append(v, d)
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll retained message data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data")
}
v = append(v, d)
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
v = append(v, d)
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
if err != nil && !errors.Is(err, redis.Nil) {
return
}
if err = v.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data")
}
return v, nil
}

View File

@@ -0,0 +1,811 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"os"
"sort"
"testing"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
miniredis "github.com/alicebob/miniredis/v2"
redis "github.com/go-redis/redis/v8"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func newHook(t *testing.T, addr string) *Hook {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: addr,
},
})
require.NoError(t, err)
return h
}
func teardown(t *testing.T, h *Hook) {
if h.db != nil {
err := h.db.FlushAll(h.ctx).Err()
require.NoError(t, err)
h.Stop()
}
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, "cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, "a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, "cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Equal(t, "redis-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestHKey(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.SetOpts(&logger, nil)
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
}
func TestInitUseDefaults(t *testing.T) {
s := miniredis.RunT(t)
s.StartAddr(defaultAddr)
defer s.Close()
h := newHook(t, defaultAddr)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h)
require.Equal(t, defaultHPrefix, h.config.HPrefix)
require.Equal(t, defaultAddr, h.config.Options.Addr)
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitBadAddr(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: "abc:123",
},
})
require.Error(t, err)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r2.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
require.Empty(t, r3.ID)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnClientExpired(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err()
require.NoError(t, err)
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, clientKey, r.ID)
h.OnClientExpired(cl)
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
require.Error(t, err)
require.ErrorIs(t, redis.Nil, err)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnSubscribedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpired(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err()
require.NoError(t, err)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved to bolt
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnQosPublishNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
n := time.Now().Unix()
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
&storage.Message{ID: "i3", T: storage.InflightKey},
).Err()
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var r []storage.Message
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
require.NoError(t, err)
require.Len(t, rows, 1)
for _, row := range rows {
var d storage.Message
err = d.UnmarshalBinary([]byte(row))
require.NoError(t, err)
r = append(r, d)
}
require.Len(t, r, 1)
require.Equal(t, "i1", r[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with clients
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err()
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredClientsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredClients()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with subscriptions
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptionsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with messages
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with messages
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessagesClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSysInfo(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with sys info
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey,
&storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
}).Err()
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfoClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.Error(t, err)
}

164
hooks/storage/storage.go Normal file
View File

@@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"encoding/json"
"errors"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
)
const (
SubscriptionKey = "SUB" // unique key to denote Subscriptions in a store
SysInfoKey = "SYS" // unique key to denote server system information in a store
RetainedKey = "RET" // unique key to denote retained messages in a store
InflightKey = "IFM" // unique key to denote inflight messages in a store
ClientKey = "CL" // unique key to denote clients in a store
)
var (
// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading.
ErrDBFileNotOpen = errors.New("db file not open")
)
// Client is a storable representation of an mqtt client.
type Client struct {
Will ClientWill `json:"will"` // will topic and payload data if applicable
Properties ClientProperties `json:"properties"` // the connect properties for the client
Username []byte `json:"username"` // the username of the client
ID string `json:"id" storm:"id"` // the client id / storage key
T string `json:"t"` // the data type (client)
Remote string `json:"remote"` // the remote address of the client
Listener string `json:"listener"` // the listener the client connected on
ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client
Clean bool `json:"clean"` // if the client requested a clean start/session
}
// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection.
type ClientProperties struct {
AuthenticationData []byte `json:"authenticationData"`
User []packets.UserProperty `json:"user"`
AuthenticationMethod string `json:"authenticationMethod"`
SessionExpiryInterval uint32 `json:"sessionExpiryInterval"`
MaximumPacketSize uint32 `json:"maximumPacketSize"`
ReceiveMaximum uint16 `json:"receiveMaximum"`
TopicAliasMaximum uint16 `json:"topicAliasMaximum"`
SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag"`
RequestProblemInfo byte `json:"requestProblemInfo"`
RequestProblemInfoFlag bool `json:"requestProblemInfoFlag"`
RequestResponseInfo byte `json:"requestResponseInfo"`
}
// ClientWill contains a will message for a client, and limited mqtt v5 properties.
type ClientWill struct {
Payload []byte `json:"payload"`
User []packets.UserProperty `json:"user"`
TopicName string `json:"topicName"`
Flag uint32 `json:"flag"`
WillDelayInterval uint32 `json:"willDelayInterval"`
Qos byte `json:"qos"`
Retain bool `json:"retain"`
}
// MarshalBinary encodes the values into a json string.
func (d Client) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Client) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// Message is a storable representation of an MQTT message (specifically publish).
type Message struct {
Properties MessageProperties `json:"properties"` // -
Payload []byte `json:"payload"` // the message payload (if retained)
T string `json:"t"` // the data type
ID string `json:"id" storm:"id"` // the storage key
Origin string `json:"origin"` // the id of the client who sent the message
TopicName string `json:"topic_name"` // the topic the message was sent to (if retained)
FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message
Created int64 `json:"created"` // the time the message was created in unixtime
Sent int64 `json:"sent"` // the last time the message was sent (for retries) in unixtime (if inflight)
PacketID uint16 `json:"packet_id"` // the unique id of the packet (if inflight)
}
// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages.
type MessageProperties struct {
CorrelationData []byte `json:"correlationData"`
SubscriptionIdentifier []int `json:"subscriptionIdentifier"`
User []packets.UserProperty `json:"user"`
ContentType string `json:"contentType"`
ResponseTopic string `json:"responseTopic"`
MessageExpiryInterval uint32 `json:"messageExpiry"`
TopicAlias uint16 `json:"topicAlias"`
PayloadFormat byte `json:"payloadFormat"`
PayloadFormatFlag bool `json:"payloadFormatFlag"`
}
// MarshalBinary encodes the values into a json string.
func (d Message) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Message) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// Subscription is a storable representation of an mqtt subscription.
type Subscription struct {
T string `json:"t"`
ID string `json:"id" storm:"id"`
Client string `json:"client"`
Filter string `json:"filter"`
Identifier int `json:"identifier"`
RetainHandling byte `json:"retain_handling"`
Qos byte `json:"qos"`
RetainAsPublished bool `json:"retain_as_pub"`
NoLocal bool `json:"no_local"`
}
// MarshalBinary encodes the values into a json string.
func (d Subscription) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Subscription) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// SystemInfo is a storable representation of the system information values.
type SystemInfo struct {
system.Info // embed the system info struct
T string `json:"t"` // the data type
ID string `json:"id" storm:"id"` // the storage key
}
// MarshalBinary encodes the values into a json string.
func (d SystemInfo) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *SystemInfo) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}

View File

@@ -0,0 +1,195 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"testing"
"time"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
var (
clientStruct = Client{
ID: "test",
T: "client",
Remote: "remote",
Listener: "listener",
Username: []byte("mochi"),
Clean: true,
Properties: ClientProperties{
SessionExpiryInterval: 2,
SessionExpiryIntervalFlag: true,
AuthenticationMethod: "a",
AuthenticationData: []byte("test"),
RequestProblemInfo: 1,
RequestProblemInfoFlag: true,
RequestResponseInfo: 1,
ReceiveMaximum: 128,
TopicAliasMaximum: 256,
User: []packets.UserProperty{
{Key: "k", Val: "v"},
},
MaximumPacketSize: 120,
},
Will: ClientWill{
Qos: 1,
Payload: []byte("abc"),
TopicName: "a/b/c",
Flag: 1,
Retain: true,
WillDelayInterval: 2,
User: []packets.UserProperty{
{Key: "k2", Val: "v2"},
},
},
}
clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)
messageStruct = Message{
T: "message",
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: 2,
Type: 3,
Qos: 1,
Dup: true,
Retain: true,
},
ID: "id",
Origin: "mochi",
TopicName: "topic",
Properties: MessageProperties{
PayloadFormat: 1,
PayloadFormatFlag: true,
MessageExpiryInterval: 20,
ContentType: "type",
ResponseTopic: "a/b/r",
CorrelationData: []byte("r"),
SubscriptionIdentifier: []int{1},
TopicAlias: 2,
User: []packets.UserProperty{
{Key: "k2", Val: "v2"},
},
},
Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
PacketID: 100,
}
messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)
subscriptionStruct = Subscription{
T: "subscription",
ID: "id",
Client: "mochi",
Filter: "a/b/c",
Qos: 1,
}
subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","identifier":0,"retain_handling":0,"qos":1,"retain_as_pub":false,"no_local":false}`)
sysInfoStruct = SystemInfo{
T: "info",
ID: "id",
Info: system.Info{
Version: "2.0.0",
Started: 1,
Uptime: 2,
BytesReceived: 3,
BytesSent: 4,
ClientsConnected: 5,
ClientsMaximum: 7,
MessagesReceived: 10,
MessagesSent: 11,
PacketsReceived: 12,
PacketsSent: 13,
Retained: 15,
Inflight: 16,
InflightDropped: 17,
},
}
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)
func TestClientMarshalBinary(t *testing.T) {
data, err := clientStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, clientJSON, data)
}
func TestClientUnmarshalBinary(t *testing.T) {
d := clientStruct
err := d.UnmarshalBinary(clientJSON)
require.NoError(t, err)
require.Equal(t, clientStruct, d)
}
func TestClientUnmarshalBinaryEmpty(t *testing.T) {
d := Client{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Client{}, d)
}
func TestMessageMarshalBinary(t *testing.T) {
data, err := messageStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, messageJSON, data)
}
func TestMessageUnmarshalBinary(t *testing.T) {
d := messageStruct
err := d.UnmarshalBinary(messageJSON)
require.NoError(t, err)
require.Equal(t, messageStruct, d)
}
func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
d := Message{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Message{}, d)
}
func TestSubscriptionMarshalBinary(t *testing.T) {
data, err := subscriptionStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, subscriptionJSON, data)
}
func TestSubscriptionUnmarshalBinary(t *testing.T) {
d := subscriptionStruct
err := d.UnmarshalBinary(subscriptionJSON)
require.NoError(t, err)
require.Equal(t, subscriptionStruct, d)
}
func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
d := Subscription{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Subscription{}, d)
}
func TestSysInfoMarshalBinary(t *testing.T) {
data, err := sysInfoStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, sysInfoJSON, data)
}
func TestSysInfoUnmarshalBinary(t *testing.T) {
d := sysInfoStruct
err := d.UnmarshalBinary(sysInfoJSON)
require.NoError(t, err)
require.Equal(t, sysInfoStruct, d)
}
func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
d := SystemInfo{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, SystemInfo{}, d)
}

622
hooks_test.go Normal file
View File

@@ -0,0 +1,622 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"errors"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
type modifiedHookBase struct {
HookBase
err error
fail bool
failAt int
}
var errTestHook = errors.New("error")
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) Provides(b byte) bool {
return true
}
func (h *modifiedHookBase) Stop() error {
if h.fail {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return true
}
func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
return true
}
func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) {
if h.fail {
return will, errTestHook
}
return will, nil
}
func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) {
if h.fail || h.failAt == 1 {
return v, errTestHook
}
return []storage.Client{
{ID: "cl1"},
{ID: "cl2"},
{ID: "cl3"},
}, nil
}
func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.fail || h.failAt == 2 {
return v, errTestHook
}
return []storage.Subscription{
{ID: "sub1"},
{ID: "sub2"},
{ID: "sub3"},
}, nil
}
func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) {
if h.fail || h.failAt == 3 {
return v, errTestHook
}
return []storage.Message{
{ID: "r1"},
{ID: "r2"},
{ID: "r3"},
}, nil
}
func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) {
if h.fail || h.failAt == 4 {
return v, errTestHook
}
return []storage.Message{
{ID: "i1"},
{ID: "i2"},
{ID: "i3"},
}, nil
}
func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.fail || h.failAt == 5 {
return v, errTestHook
}
return storage.SystemInfo{
Info: system.Info{
Version: "2.0.0",
},
}, nil
}
type providesCheckHook struct {
HookBase
}
func (h *providesCheckHook) Provides(b byte) bool {
return b == OnConnect
}
func TestHooksProvides(t *testing.T) {
h := new(Hooks)
err := h.Add(new(providesCheckHook), nil)
require.NoError(t, err)
err = h.Add(new(HookBase), nil)
require.NoError(t, err)
require.True(t, h.Provides(OnConnect, OnDisconnect))
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddAndLen(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
}
func TestHooksAddInitFailure(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), map[string]any{})
require.Error(t, err)
require.Equal(t, int64(0), atomic.LoadInt64(&h.qty))
}
func TestHooksStop(t *testing.T) {
h := new(Hooks)
h.Log = &logger
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
h.Stop()
}
// coverage: also cover some empty functions
func TestHooksNonReturns(t *testing.T) {
h := new(Hooks)
cl := new(Client)
for i := 0; i < 2; i++ {
t.Run("step-"+strconv.Itoa(i), func(t *testing.T) {
// on first iteration, check without hook methods
h.OnStarted()
h.OnStopped()
h.OnSysInfoTick(new(system.Info))
h.OnConnect(cl, packets.Packet{})
h.OnSessionEstablished(cl, packets.Packet{})
h.OnDisconnect(cl, nil, false)
h.OnPacketSent(cl, packets.Packet{}, []byte{})
h.OnPacketProcessed(cl, packets.Packet{}, nil)
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})
h.OnQosDropped(cl, packets.Packet{})
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
h.OnExpireInflights(cl, time.Now().Unix()-1)
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
})
}
}
func TestHooksOnConnectAuthenticate(t *testing.T) {
h := new(Hooks)
ok := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, ok)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
ok = h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.True(t, ok)
}
func TestHooksOnACLCheck(t *testing.T) {
h := new(Hooks)
ok := h.OnACLCheck(new(Client), "a/b/c", true)
require.False(t, ok)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
ok = h.OnACLCheck(new(Client), "a/b/c", true)
require.True(t, ok)
}
func TestHooksOnSubscribe(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
pki := packets.Packet{
Filters: packets.Subscriptions{
{Filter: "a/b/c", Qos: 1},
},
}
pk := h.OnSubscribe(new(Client), pki)
require.EqualValues(t, pk, pki)
}
func TestHooksOnSelectSubscribers(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
subs := &Subscribers{
Subscriptions: map[string]packets.Subscription{
"cl1": {Filter: "a/b/c"},
},
}
subs2 := h.OnSelectSubscribers(subs, packets.Packet{})
require.EqualValues(t, subs, subs2)
}
func TestHooksOnUnsubscribe(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
pki := packets.Packet{
Filters: packets.Subscriptions{
{Filter: "a/b/c", Qos: 1},
},
}
pk := h.OnUnsubscribe(new(Client), pki)
require.EqualValues(t, pk, pki)
}
func TestHooksOnPublish(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: failure
hook.fail = true
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
hook.err = packets.ErrRejectPacket
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrRejectPacket)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnPacketRead(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: failure
hook.fail = true
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
hook.err = packets.ErrRejectPacket
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrRejectPacket)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnAuthPacket(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
hook.fail = true
pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnPacketEncode(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10})
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnLWT(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.Equal(t, "a/b/c", lwt.TopicName)
// coverage: fail lwt
hook.fail = true
lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.Equal(t, "a/b/c", lwt.TopicName)
}
func TestHooksStoredClients(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredClients()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredClients()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredSubscriptions(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredSubscriptions()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredRetainedMessages(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredRetainedMessages()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredInflightMessages(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredInflightMessages()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredSysInfo(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "", v.Info.Version)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", v.Info.Version)
hook.fail = true
v, err = h.StoredSysInfo()
require.Error(t, err)
require.Equal(t, "", v.Info.Version)
}
func TestHookBaseID(t *testing.T) {
h := new(HookBase)
require.Equal(t, "base", h.ID())
}
func TestHookBaseProvidesNone(t *testing.T) {
h := new(HookBase)
require.False(t, h.Provides(OnConnect))
require.False(t, h.Provides(OnDisconnect))
}
func TestHookBaseInit(t *testing.T) {
h := new(HookBase)
require.Nil(t, h.Init(nil))
}
func TestHookBaseSetOpts(t *testing.T) {
h := new(HookBase)
h.SetOpts(&logger, new(HookOptions))
require.NotNil(t, h.Log)
require.NotNil(t, h.Opts)
}
func TestHookBaseClose(t *testing.T) {
h := new(HookBase)
require.Nil(t, h.Stop())
}
func TestHookBaseOnConnectAuthenticate(t *testing.T) {
h := new(HookBase)
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, v)
}
func TestHookBaseOnACLCheck(t *testing.T) {
h := new(HookBase)
v := h.OnACLCheck(new(Client), "topic", true)
require.False(t, v)
}
func TestHookBaseOnPublish(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnPacketRead(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnAuthPacket(t *testing.T) {
h := new(HookBase)
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnLWT(t *testing.T) {
h := new(HookBase)
lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.NoError(t, err)
require.Equal(t, "a/b/c", lwt.TopicName)
}
func TestHookBaseStoredClients(t *testing.T) {
h := new(HookBase)
v, err := h.StoredClients()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredSubscriptions(t *testing.T) {
h := new(HookBase)
v, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredInflightMessages(t *testing.T) {
h := new(HookBase)
v, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredRetainedMessages(t *testing.T) {
h := new(HookBase)
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoreSysInfo(t *testing.T) {
h := new(HookBase)
v, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "", v.Version)
}

144
inflight.go Normal file
View File

@@ -0,0 +1,144 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sort"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/v2/packets"
)
// Inflight is a map of InflightMessage keyed on packet id.
type Inflight struct {
sync.RWMutex
internal map[uint16]packets.Packet // internal contains the inflight packets
receiveQuota int32 // remaining inbound qos quota for flow control
sendQuota int32 // remaining outbound qos quota for flow control
maximumReceiveQuota int32 // maximum allowed receive quota
maximumSendQuota int32 // maximum allowed send quota
}
// NewInflights returns a new instance of an Inflight packets map.
func NewInflights() *Inflight {
return &Inflight{
internal: map[uint16]packets.Packet{},
}
}
// Set adds or updates an inflight packet by packet id.
func (i *Inflight) Set(m packets.Packet) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[m.PacketID]
i.internal[m.PacketID] = m
return !ok
}
// Get returns an inflight packet by packet id.
func (i *Inflight) Get(id uint16) (packets.Packet, bool) {
i.RLock()
defer i.RUnlock()
if m, ok := i.internal[id]; ok {
return m, true
}
return packets.Packet{}, false
}
// Len returns the size of the inflight messages map.
func (i *Inflight) Len() int {
i.RLock()
defer i.RUnlock()
return len(i.internal)
}
// GetAll returns all the inflight messages.
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
i.RLock()
defer i.RUnlock()
m := []packets.Packet{}
for _, v := range i.internal {
if !immediate || (immediate && v.Expiry < 0) {
m = append(m, v)
}
}
sort.Slice(m, func(i, j int) bool {
return uint16(m[i].Created) < uint16(m[j].Created)
})
return m
}
// NextImmediate returns the next inflight packet which is indicated to be sent immediately.
// This typically occurs when the quota has been exhausted, and we need to wait until new quota
// is free to continue sending.
func (i *Inflight) NextImmediate() (packets.Packet, bool) {
i.RLock()
defer i.RUnlock()
m := i.GetAll(true)
if len(m) > 0 {
return m[0], true
}
return packets.Packet{}, false
}
// Delete removes an in-flight message from the map. Returns true if the message existed.
func (i *Inflight) Delete(id uint16) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[id]
delete(i.internal, id)
return ok
}
// TakeRecieveQuota reduces the receive quota by 1.
func (i *Inflight) TakeReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) > 0 {
atomic.AddInt32(&i.receiveQuota, -1)
}
}
// TakeRecieveQuota increases the receive quota by 1.
func (i *Inflight) ReturnReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
atomic.AddInt32(&i.receiveQuota, 1)
}
}
// ResetReceiveQuota resets the receive quota to the maximum allowed value.
func (i *Inflight) ResetReceiveQuota(n int32) {
atomic.StoreInt32(&i.receiveQuota, n)
atomic.StoreInt32(&i.maximumReceiveQuota, n)
}
// TakeSendQuota reduces the send quota by 1.
func (i *Inflight) TakeSendQuota() {
if atomic.LoadInt32(&i.sendQuota) > 0 {
atomic.AddInt32(&i.sendQuota, -1)
}
}
// ReturnSendQuota increases the send quota by 1.
func (i *Inflight) ReturnSendQuota() {
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
atomic.AddInt32(&i.sendQuota, 1)
}
}
// ResetSendQuota resets the send quota to the maximum allowed value.
func (i *Inflight) ResetSendQuota(n int32) {
atomic.StoreInt32(&i.sendQuota, n)
atomic.StoreInt32(&i.maximumSendQuota, n)
}

189
inflight_test.go Normal file
View File

@@ -0,0 +1,189 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sync/atomic"
"testing"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
func TestInflightSet(t *testing.T) {
cl, _, _ := newClient()
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.True(t, r)
require.NotNil(t, cl.State.Inflight.internal[1])
require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID)
r = cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.False(t, r)
}
func TestInflightGet(t *testing.T) {
cl, _, _ := newClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
msg, ok := cl.State.Inflight.Get(2)
require.True(t, ok)
require.NotEqual(t, 0, msg.PacketID)
}
func TestInflightGetAllAndImmediate(t *testing.T) {
cl, _, _ := newClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
require.Equal(t, []packets.Packet{
{PacketID: 1, Created: 1},
{PacketID: 2, Created: 2},
{PacketID: 3, Created: 3, Expiry: -1},
{PacketID: 4, Created: 4, Expiry: -1},
{PacketID: 5, Created: 5},
}, cl.State.Inflight.GetAll(false))
require.Equal(t, []packets.Packet{
{PacketID: 3, Created: 3, Expiry: -1},
{PacketID: 4, Created: 4, Expiry: -1},
}, cl.State.Inflight.GetAll(true))
}
func TestInflightLen(t *testing.T) {
cl, _, _ := newClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
require.NotNil(t, cl.State.Inflight.internal[3])
r := cl.State.Inflight.Delete(3)
require.True(t, r)
require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID)
_, ok := cl.State.Inflight.Get(3)
require.False(t, ok)
r = cl.State.Inflight.Delete(3)
require.False(t, r)
}
func TestResetReceiveQuota(t *testing.T) {
i := NewInflights()
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
i.ResetReceiveQuota(6)
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota))
}
func TestReceiveQuota(t *testing.T) {
i := NewInflights()
i.receiveQuota = 4
i.maximumReceiveQuota = 5
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
// Return 1
i.ReturnReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Try to go over max limit
i.ReturnReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Reset to max 1
i.ResetReceiveQuota(1)
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
// Take 1
i.TakeReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
// Try to go below zero
i.TakeReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}
func TestResetSendQuota(t *testing.T) {
i := NewInflights()
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
i.ResetSendQuota(6)
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota))
}
func TestSendQuota(t *testing.T) {
i := NewInflights()
i.sendQuota = 4
i.maximumSendQuota = 5
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
// Return 1
i.ReturnSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Try to go over max limit
i.ReturnSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Reset to max 1
i.ResetSendQuota(1)
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
// Take 1
i.TakeSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
// Try to go below zero
i.TakeSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}
func TestNextImmediate(t *testing.T) {
cl, _, _ := newClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
pk, ok := cl.State.Inflight.NextImmediate()
require.True(t, ok)
require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk)
r := cl.State.Inflight.Delete(3)
require.True(t, r)
pk, ok = cl.State.Inflight.NextImmediate()
require.True(t, ok)
require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk)
r = cl.State.Inflight.Delete(4)
require.True(t, r)
_, ok = cl.State.Inflight.NextImmediate()
require.False(t, ok)
}

139
listeners/http_sysinfo.go Normal file
View File

@@ -0,0 +1,139 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"encoding/json"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
)
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // the http server
log *zerolog.Logger // server logger
sysInfo *system.Info // pointers to the server data
end uint32 // ensure the close methods are only called once
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats {
if config == nil {
config = new(Config)
}
return &HTTPStats{
id: id,
address: address,
sysInfo: sysInfo,
config: config,
}
}
// ID returns the id of the listener.
func (l *HTTPStats) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *HTTPStats) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *HTTPStats) Protocol() string {
if l.listen != nil && l.listen.TLSConfig != nil {
return "https"
}
return "http"
}
// Init initializes the listener.
func (l *HTTPStats) Init(log *zerolog.Logger) error {
l.log = log
mux := http.NewServeMux()
mux.HandleFunc("/", l.jsonHandler)
l.listen = &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
Addr: l.address,
Handler: mux,
}
if l.config.TLSConfig != nil {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
// Serve starts listening for new connections and serving responses.
func (l *HTTPStats) Serve(establish EstablishFn) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *HTTPStats) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info := &system.Info{
Version: l.sysInfo.Version,
Started: atomic.LoadInt64(&l.sysInfo.Started),
Time: atomic.LoadInt64(&l.sysInfo.Time),
Uptime: atomic.LoadInt64(&l.sysInfo.Uptime),
BytesReceived: atomic.LoadInt64(&l.sysInfo.BytesReceived),
BytesSent: atomic.LoadInt64(&l.sysInfo.BytesSent),
ClientsConnected: atomic.LoadInt64(&l.sysInfo.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&l.sysInfo.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&l.sysInfo.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&l.sysInfo.MessagesReceived),
MessagesSent: atomic.LoadInt64(&l.sysInfo.MessagesSent),
InflightDropped: atomic.LoadInt64(&l.sysInfo.InflightDropped),
Subscriptions: atomic.LoadInt64(&l.sysInfo.Subscriptions),
PacketsReceived: atomic.LoadInt64(&l.sysInfo.PacketsReceived),
PacketsSent: atomic.LoadInt64(&l.sysInfo.PacketsSent),
Retained: atomic.LoadInt64(&l.sysInfo.Retained),
Inflight: atomic.LoadInt64(&l.sysInfo.Inflight),
MemoryAlloc: atomic.LoadInt64(&l.sysInfo.MemoryAlloc),
Threads: atomic.LoadInt64(&l.sysInfo.Threads),
}
out, err := json.MarshalIndent(info, "", "\t")
if err != nil {
io.WriteString(w, err.Error())
}
w.Write(out)
}

View File

@@ -0,0 +1,127 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
func TestNewHTTPStats(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestHTTPStatsID(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "t1", l.ID())
}
func TestHTTPStatsAddress(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, testAddr, l.Address())
}
func TestHTTPStatsProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "http", l.Protocol())
}
func TestHTTPStatsTLSProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
}, nil)
l.Init(nil)
require.Equal(t, "https", l.Protocol())
}
func TestHTTPStatsInit(t *testing.T) {
sysInfo := new(system.Info)
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
require.NotNil(t, l.sysInfo)
require.Equal(t, sysInfo, l.sysInfo)
require.NotNil(t, l.listen)
require.Equal(t, testAddr, l.listen.Addr)
}
func TestHTTPStatsServeAndClose(t *testing.T) {
sysInfo := &system.Info{
Version: "test",
}
// setup http stats listener
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
// get body from stats address
resp, err := http.Get("http://localhost" + testAddr)
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// decode body from json and check data
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
// ensure listening is closed
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Get("http://localhost" + testAddr)
require.Error(t, err)
<-o
}
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
sysInfo := &system.Info{
Version: "test",
}
l := NewHTTPStats("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
}, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
l.Close(MockCloser)
}

View File

@@ -1,94 +1,91 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"net"
"sync"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
"github.com/rs/zerolog"
)
// Config contains configuration values for a listener.
type Config struct {
Auth auth.Controller // an authentication controller containing auth and ACL logic.
TLS *TLS // the TLS certficates and settings for the connection.
// TLSConfig is a tls.Config configuration to be used with the listener.
// See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config
}
// TLS contains the TLS certificates and settings for the listener connection.
type TLS struct {
Certificate []byte // the body of a public certificate.
PrivateKey []byte // the body of a private key.
}
// EstablishFunc is a callback function for establishing new clients.
type EstablishFunc func(id string, c net.Conn, ac auth.Controller) error
// EstablishFn is a callback function for establishing new clients.
type EstablishFn func(id string, c net.Conn) error
// CloseFunc is a callback function for closing all listener clients.
type CloseFunc func(id string)
type CloseFn func(id string)
// Listener is an interface for network listeners. A network listener listens
// for incoming client connections and adds them to the server.
type Listener interface {
SetConfig(*Config) // set the listener config.
Listen(s *system.Info) error // open the network address.
Serve(EstablishFunc) // starting actively listening for new connections.
ID() string // return the id of the listener.
Close(CloseFunc) // stop and close the listener.
Init(*zerolog.Logger) error // open the network address
Serve(EstablishFn) // starting actively listening for new connections
ID() string // return the id of the listener
Address() string // the address of the listener
Protocol() string // the protocol in use by the listener
Close(CloseFn) // stop and close the listener
}
// Listeners contains the network listeners for the broker.
type Listeners struct {
sync.RWMutex
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
internal map[string]Listener // a map of active listeners.
system *system.Info // pointers to system info.
sync.RWMutex
}
// New returns a new instance of Listeners.
func New(s *system.Info) *Listeners {
func New() *Listeners {
return &Listeners{
internal: map[string]Listener{},
system: s,
}
}
// Add adds a new listener to the listeners map, keyed on id.
func (l *Listeners) Add(val Listener) {
l.Lock()
defer l.Unlock()
l.internal[val.ID()] = val
l.Unlock()
}
// Get returns the value of a listener if it exists.
func (l *Listeners) Get(id string) (Listener, bool) {
l.RLock()
defer l.RUnlock()
val, ok := l.internal[id]
l.RUnlock()
return val, ok
}
// Len returns the length of the listeners map.
func (l *Listeners) Len() int {
l.RLock()
val := len(l.internal)
l.RUnlock()
return val
defer l.RUnlock()
return len(l.internal)
}
// Delete removes a listener from the internal map.
func (l *Listeners) Delete(id string) {
l.Lock()
defer l.Unlock()
delete(l.internal, id)
l.Unlock()
}
// Serve starts a listener serving from the internal map.
func (l *Listeners) Serve(id string, establisher EstablishFunc) {
func (l *Listeners) Serve(id string, establisher EstablishFn) {
l.RLock()
defer l.RUnlock()
listener := l.internal[id]
l.RUnlock()
go func(e EstablishFunc) {
go func(e EstablishFn) {
defer l.wg.Done()
l.wg.Add(1)
listener.Serve(e)
@@ -96,7 +93,7 @@ func (l *Listeners) Serve(id string, establisher EstablishFunc) {
}
// ServeAll starts all listeners serving from the internal map.
func (l *Listeners) ServeAll(establisher EstablishFunc) {
func (l *Listeners) ServeAll(establisher EstablishFn) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
@@ -112,15 +109,16 @@ func (l *Listeners) ServeAll(establisher EstablishFunc) {
}
// Close stops a listener from the internal map.
func (l *Listeners) Close(id string, closer CloseFunc) {
func (l *Listeners) Close(id string, closer CloseFn) {
l.RLock()
listener := l.internal[id]
l.RUnlock()
listener.Close(closer)
defer l.RUnlock()
if listener, ok := l.internal[id]; ok {
listener.Close(closer)
}
}
// CloseAll iterates and closes all registered listeners.
func (l *Listeners) CloseAll(closer CloseFunc) {
func (l *Listeners) CloseAll(closer CloseFn) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))

177
listeners/listeners_test.go Normal file
View File

@@ -0,0 +1,177 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"log"
"os"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
const testAddr = ":22222"
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
tlsConfigBasic *tls.Config
)
func init() {
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfigBasic = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
}
func TestNew(t *testing.T) {
l := New()
require.NotNil(t, l.internal)
}
func TestAddListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
require.Contains(t, l.internal, "t1")
}
func TestGetListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
g, ok := l.Get("t1")
require.True(t, ok)
require.Equal(t, g.ID(), "t1")
}
func TestLenListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
require.Equal(t, 2, l.Len())
}
func TestDeleteListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
require.Contains(t, l.internal, "t1")
l.Delete("t1")
_, ok := l.Get("t1")
require.False(t, ok)
require.Nil(t, l.internal["t1"])
}
func TestServeListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
require.False(t, l.internal["t1"].(*MockListener).IsServing())
}
func TestServeAllListeners(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
l.Add(NewMockListener("t3", testAddr))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
require.True(t, l.internal["t2"].(*MockListener).IsServing())
require.True(t, l.internal["t3"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
l.Close("t2", MockCloser)
l.Close("t3", MockCloser)
require.False(t, l.internal["t1"].(*MockListener).IsServing())
require.False(t, l.internal["t2"].(*MockListener).IsServing())
require.False(t, l.internal["t3"].(*MockListener).IsServing())
}
func TestCloseListener(t *testing.T) {
l := New()
mocked := NewMockListener("t1", testAddr)
l.Add(mocked)
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
var closed bool
l.Close("t1", func(id string) {
closed = true
})
require.True(t, closed)
}
func TestCloseAllListeners(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
l.Add(NewMockListener("t3", testAddr))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
require.True(t, l.internal["t2"].(*MockListener).IsServing())
require.True(t, l.internal["t3"].(*MockListener).IsServing())
closed := make(map[string]bool)
l.CloseAll(func(id string) {
closed[id] = true
})
require.Contains(t, closed, "t1")
require.Contains(t, closed, "t2")
require.Contains(t, closed, "t3")
require.True(t, closed["t1"])
require.True(t, closed["t2"])
require.True(t, closed["t3"])
}

103
listeners/mock.go Normal file
View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"fmt"
"net"
"sync"
"github.com/rs/zerolog"
)
// MockEstablisher is a function signature which can be used in testing.
func MockEstablisher(id string, c net.Conn) error {
return nil
}
// MockCloser is a function signature which can be used in testing.
func MockCloser(id string) {}
// MockListener is a mock listener for establishing client connections.
type MockListener struct {
sync.RWMutex
id string // the id of the listener
address string // the network address the listener binds to
Config *Config // configuration for the listener
done chan bool // indicate the listener is done
Serving bool // indicate the listener is serving
Listening bool // indiciate the listener is listening
ErrListen bool // throw an error on listen
}
// NewMockListener returns a new instance of MockListener.
func NewMockListener(id, address string) *MockListener {
return &MockListener{
id: id,
address: address,
done: make(chan bool),
}
}
// Serve serves the mock listener.
func (l *MockListener) Serve(establisher EstablishFn) {
l.Lock()
l.Serving = true
l.Unlock()
for range l.done {
return
}
}
// Init initializes the listener.
func (l *MockListener) Init(log *zerolog.Logger) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
}
l.Lock()
defer l.Unlock()
l.Listening = true
return nil
}
// ID returns the id of the mock listener.
func (l *MockListener) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *MockListener) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *MockListener) Protocol() string {
return "mock"
}
// Close closes the mock listener.
func (l *MockListener) Close(closer CloseFn) {
l.Lock()
defer l.Unlock()
l.Serving = false
closer(l.id)
close(l.done)
}
// IsServing indicates whether the mock listener is serving.
func (l *MockListener) IsServing() bool {
l.Lock()
defer l.Unlock()
return l.Serving
}
// IsListening indicates whether the mock listener is listening.
func (l *MockListener) IsListening() bool {
l.Lock()
defer l.Unlock()
return l.Listening
}

99
listeners/mock_test.go Normal file
View File

@@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestMockEstablisher(t *testing.T) {
_, w := net.Pipe()
err := MockEstablisher("t1", w)
require.NoError(t, err)
w.Close()
}
func TestNewMockListener(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.id)
require.Equal(t, testAddr, mocked.address)
}
func TestMockListenerID(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.ID())
}
func TestMockListenerAddress(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, testAddr, mocked.Address())
}
func TestMockListenerProtocol(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "mock", mocked.Protocol())
}
func TestNewMockListenerIsListening(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsListening())
}
func TestNewMockListenerIsServing(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsServing())
}
func TestNewMockListenerInit(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.id)
require.Equal(t, testAddr, mocked.address)
require.Equal(t, false, mocked.IsListening())
err := mocked.Init(nil)
require.NoError(t, err)
require.Equal(t, true, mocked.IsListening())
}
func TestNewMockListenerInitFailure(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
mocked.ErrListen = true
err := mocked.Init(nil)
require.Error(t, err)
}
func TestMockListenerServe(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsServing())
o := make(chan bool)
go func(o chan bool) {
mocked.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
require.Equal(t, true, mocked.IsServing())
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
mocked.Init(nil)
}
func TestMockListenerClose(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}

108
listeners/tcp.go Normal file
View File

@@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"net"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// TCP is a listener for establishing client connections on basic TCP protocol.
type TCP struct { // [MQTT-4.2.0-1]
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
listen net.Listener // a net.Listener which will listen for new clients
config *Config // configuration values for the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
func NewTCP(id, address string, config *Config) *TCP {
if config == nil {
config = new(Config)
}
return &TCP{
id: id,
address: address,
config: config,
}
}
// ID returns the id of the listener.
func (l *TCP) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *TCP) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *TCP) Protocol() string {
return "tcp"
}
// Init initializes the listener.
func (l *TCP) Init(log *zerolog.Logger) error {
l.log = log
var err error
if l.config.TLSConfig != nil {
l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig)
} else {
l.listen, err = net.Listen("tcp", l.address)
}
return err
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *TCP) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

131
listeners/tcp_test.go Normal file
View File

@@ -0,0 +1,131 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewTCP(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestTCPID(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "t1", l.ID())
}
func TestTCPAddress(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, testAddr, l.Address())
}
func TestTCPProtocol(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "tcp", l.Protocol())
}
func TestTCPProtocolTLS(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
l.Init(&logger)
defer l.listen.Close()
require.Equal(t, "tcp", l.Protocol())
}
func TestTCPInit(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewTCP("t2", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
require.NotNil(t, l2.config.TLSConfig)
}
func TestTCPServeAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestTCPServeTLSAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestTCPEstablishThenEnd(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

167
listeners/websocket.go Normal file
View File

@@ -0,0 +1,167 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"errors"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
var (
// ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary")
)
// Websocket is a listener for establishing websocket connections.
type Websocket struct { // [MQTT-4.2.0-1]
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // an http server for serving websocket connections
log *zerolog.Logger // server logger
establish EstablishFn // the server's establish connection handler
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
end uint32 // ensure the close methods are only called once
}
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string, config *Config) *Websocket {
if config == nil {
config = new(Config)
}
return &Websocket{
id: id,
address: address,
config: config,
upgrader: &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool {
return true
},
},
}
}
// ID returns the id of the listener.
func (l *Websocket) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Websocket) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *Websocket) Protocol() string {
if l.config.TLSConfig != nil {
return "wss"
}
return "ws"
}
// Init initializes the listener.
func (l *Websocket) Init(log *zerolog.Logger) error {
l.log = log
mux := http.NewServeMux()
mux.HandleFunc("/", l.handler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
TLSConfig: l.config.TLSConfig,
ReadTimeout: 60 * time.Second,
WriteTimeout: 60 * time.Second,
}
return nil
}
// handler upgrades and handles an incoming websocket connection.
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
c, err := l.upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
err = l.establish(l.id, &wsConn{c.UnderlyingConn(), c})
if err != nil {
l.log.Warn().Err(err).Send()
}
}
// Serve starts waiting for new Websocket connections, and calls the connection
// establishment callback for any received.
func (l *Websocket) Serve(establish EstablishFn) {
l.establish = establish
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *Websocket) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
type wsConn struct {
net.Conn
c *websocket.Conn
}
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
}
return r.Read(p)
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
}
return len(p), nil
}
// Close signals the underlying websocket conn to close.
func (ws *wsConn) Close() error {
return ws.Conn.Close()
}

114
listeners/websocket_test.go Normal file
View File

@@ -0,0 +1,114 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
func TestNewWebsocket(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestWebsocketID(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "t1", l.ID())
}
func TestWebsocketAddress(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, testAddr, l.Address())
}
func TestWebsocketProtocol(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "ws", l.Protocol())
}
func TestWebsocketProtocoTLS(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
require.Equal(t, "wss", l.Protocol())
}
func TestWebsockeInit(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Nil(t, l.listen)
err := l.Init(nil)
require.NoError(t, err)
require.NotNil(t, l.listen)
}
func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
}
func TestWebsocketServeTLSAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
e := make(chan bool)
l.establish = func(id string, c net.Conn) error {
e <- true
return nil
}
s := httptest.NewServer(http.HandlerFunc(l.handler))
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
require.NoError(t, err)
require.Equal(t, true, <-e)
s.Close()
ws.Close()
}

View File

@@ -1,12 +1,18 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"encoding/binary"
"io"
"unicode/utf8"
"unsafe"
)
// bytesToString provides a zero-alloc, no-copy byte to string conversion.
// bytesToString provides a zero-alloc no-copy byte to string conversion.
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
func bytesToString(bs []byte) string {
return *(*string)(unsafe.Pointer(&bs))
@@ -15,12 +21,21 @@ func bytesToString(bs []byte) string {
// decodeUint16 extracts the value of two bytes from a byte array.
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
if len(buf) < offset+2 {
return 0, 0, ErrOffsetUintOutOfRange
return 0, 0, ErrMalformedOffsetUintOutOfRange
}
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
}
// decodeUint32 extracts the value of four bytes from a byte array.
func decodeUint32(buf []byte, offset int) (uint32, int, error) {
if len(buf) < offset+4 {
return 0, 0, ErrMalformedOffsetUintOutOfRange
}
return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil
}
// decodeString extracts a string from a byte array, beginning at an offset.
func decodeString(buf []byte, offset int) (string, int, error) {
b, n, err := decodeBytes(buf, offset)
@@ -28,22 +43,27 @@ func decodeString(buf []byte, offset int) (string, int, error) {
return "", 0, err
}
if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5]
return "", 0, ErrMalformedInvalidUTF8
}
return bytesToString(b), n, nil
}
// validUTF8 checks if the byte array contains valid UTF-8 characters.
func validUTF8(b []byte) bool {
return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2]
}
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
length, next, err := decodeUint16(buf, offset)
if err != nil {
return make([]byte, 0, 0), 0, err
return make([]byte, 0), 0, err
}
if next+int(length) > len(buf) {
return make([]byte, 0, 0), 0, ErrOffsetStrOutOfRange
}
if !validUTF8(buf[next : next+int(length)]) {
return make([]byte, 0, 0), 0, ErrOffsetStrInvalidUTF8
return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange
}
return buf[next : next+int(length)], next + int(length), nil
@@ -52,7 +72,7 @@ func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
// decodeByte extracts the value of a byte from a byte array.
func decodeByte(buf []byte, offset int) (byte, int, error) {
if len(buf) <= offset {
return 0, 0, ErrOffsetByteOutOfRange
return 0, 0, ErrMalformedOffsetByteOutOfRange
}
return buf[offset], offset + 1, nil
}
@@ -60,7 +80,7 @@ func decodeByte(buf []byte, offset int) (byte, int, error) {
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
if len(buf) <= offset {
return false, 0, ErrOffsetBoolOutOfRange
return false, 0, ErrMalformedOffsetBoolOutOfRange
}
return 1&buf[offset] > 0, offset + 1, nil
}
@@ -75,7 +95,7 @@ func encodeBool(b bool) byte {
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
func encodeBytes(val []byte) []byte {
// In many circumstances the number of bytes being encoded is small.
// In most circumstances the number of bytes being encoded is small.
// Setting the cap to a low amount allows us to account for those without
// triggering allocation growth on append unless we need to.
buf := make([]byte, 2, 32)
@@ -90,6 +110,13 @@ func encodeUint16(val uint16) []byte {
return buf
}
// encodeUint32 encodes a uint16 value to a byte array.
func encodeUint32(val uint32) []byte {
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, val)
return buf
}
// encodeString encodes a string to a byte array.
func encodeString(val string) []byte {
// Like encodeBytes, we set the cap to a small number to avoid
@@ -99,16 +126,47 @@ func encodeString(val string) []byte {
return append(buf, []byte(val)...)
}
// validUTF8 checks if the byte array contains valid UTF-8 characters, specifically
// conforming to the MQTT specification requirements.
func validUTF8(b []byte) bool {
// [MQTT-1.4.0-1] The character data in a UTF-8 encoded string MUST be well-formed UTF-8...
if !utf8.Valid(b) {
return false
// encodeLength writes length bits for the header.
func encodeLength(b *bytes.Buffer, length int64) {
// 1.5.5 Variable Byte Integer encode non-normative
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
for {
eb := byte(length % 128)
length /= 128
if length > 0 {
eb |= 0x80
}
b.WriteByte(eb)
if length == 0 {
break // [MQTT-1.5.5-1]
}
}
}
func DecodeLength(b io.ByteReader) (n, bu int, err error) {
// see 1.5.5 Variable Byte Integer decode non-normative
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
var multiplier uint32
var value uint32
bu = 1
for {
eb, err := b.ReadByte()
if err != nil {
return 0, bu, err
}
value |= uint32(eb&127) << multiplier
if value > 268435455 {
return 0, bu, ErrMalformedVariableByteInteger
}
if (eb & 128) == 0 {
break
}
multiplier += 7
bu++
}
// [MQTT-1.4.0-2] A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000...
// ...
return true
return int(value), bu, nil
}

422
packets/codec_test.go Normal file
View File

@@ -0,0 +1,422 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"errors"
"fmt"
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBytesToString(t *testing.T) {
b := []byte{'a', 'b', 'c'}
require.Equal(t, "abc", bytesToString(b))
}
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
rawBytes []byte
result string
offset int
shouldFail error
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: "a/b/c/d",
},
{
offset: 14,
rawBytes: []byte{
Connect << 4, 17, // Fixed header
0, 6, // Protocol Name - MSB+LSB
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
3, // Protocol Version
0, // Packet Flags
0, 30, // Keepalive
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "zen"},
},
result: "hey",
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: "1/2/3/4/a/b/c/d/e/^/@/!",
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: "x/y/z",
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
{
offset: 17,
rawBytes: []byte{
Connect << 4, 0, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 6, // Will Topic - MSB+LSB
'l',
},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrMalformedInvalidUTF8,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3]
result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0)
require.NoError(t, err)
require.Equal(t, "\ufeff", result)
}
func TestDecodeBytes(t *testing.T) {
expect := []struct {
rawBytes []byte
result []uint8
next int
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session)
result: []byte{0x4d, 0x51, 0x54, 0x54},
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start
result: []byte{0x4d, 0x51, 0x54, 0x54},
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 0,
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func TestDecodeByte(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint8
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
result: uint8(0x00),
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x04),
offset: 1,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x4d),
offset: 2,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x51),
offset: 3,
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
offset: 8,
shouldFail: ErrMalformedOffsetByteOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
})
}
}
func TestDecodeUint16(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint16
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x07),
offset: 0,
},
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x761),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
})
}
}
func TestDecodeUint32(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint32
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 0, 0, 7, 8},
result: uint32(7),
offset: 0,
},
{
rawBytes: []byte{0, 0, 1, 226, 64, 8},
result: uint32(123456),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+4, offset)
})
}
}
func TestDecodeByteBool(t *testing.T) {
expect := []struct {
rawBytes []byte
result bool
offset int
shouldFail error
}{
{
rawBytes: []byte{0x00, 0x00},
result: false,
},
{
rawBytes: []byte{0x01, 0x00},
result: true,
},
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: ErrMalformedOffsetBoolOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, 1, offset)
})
}
}
func TestDecodeLength(t *testing.T) {
b := bytes.NewBuffer([]byte{0x78})
n, bu, err := DecodeLength(b)
require.NoError(t, err)
require.Equal(t, 120, n)
require.Equal(t, 1, bu)
b = bytes.NewBuffer([]byte{255, 255, 255, 127})
n, bu, err = DecodeLength(b)
require.NoError(t, err)
require.Equal(t, 268435455, n)
require.Equal(t, 4, bu)
}
func TestDecodeLengthErrors(t *testing.T) {
b := bytes.NewBuffer([]byte{})
_, _, err := DecodeLength(b)
require.Error(t, err)
b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f})
_, _, err = DecodeLength(b)
require.Error(t, err)
require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
}
func TestEncodeBool(t *testing.T) {
result := encodeBool(true)
require.Equal(t, byte(1), result)
result = encodeBool(false)
require.Equal(t, byte(0), result)
// Check failure.
result = encodeBool(false)
require.NotEqual(t, byte(1), result)
}
func TestEncodeBytes(t *testing.T) {
result := encodeBytes([]byte("testing"))
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result)
result = encodeBytes([]byte("testing"))
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result)
}
func TestEncodeUint16(t *testing.T) {
result := encodeUint16(0)
require.Equal(t, []byte{0x00, 0x00}, result)
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result)
result = encodeUint16(65535)
require.Equal(t, []byte{0xff, 0xff}, result)
}
func TestEncodeUint32(t *testing.T) {
result := encodeUint32(7)
require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result)
result = encodeUint32(32767)
require.Equal(t, []byte{0, 0, 127, 255}, result)
result = encodeUint32(math.MaxUint32)
require.Equal(t, []byte{255, 255, 255, 255}, result)
}
func TestEncodeString(t *testing.T) {
result := encodeString("testing")
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result)
result = encodeString("")
require.Equal(t, []uint8{0x00, 0x00}, result)
result = encodeString("a")
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result)
result = encodeString("b")
require.NotEqual(t, []uint8{0x00, 0x00}, result)
}
func TestEncodeLength(t *testing.T) {
b := new(bytes.Buffer)
encodeLength(b, 120)
require.Equal(t, []byte{0x78}, b.Bytes())
b = new(bytes.Buffer)
encodeLength(b, math.MaxInt64)
require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes())
}
func TestValidUTF8(t *testing.T) {
require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}))
require.False(t, validUTF8([]byte{0xff, 0xff}))
require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74}))
}

127
packets/codes.go Normal file
View File

@@ -0,0 +1,127 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
// Code contains a reason code and reason string for a response.
type Code struct {
Reason string
Code byte
}
// String returns the readable reason for a code.
func (c Code) String() string {
return c.Reason
}
// Error returns the readable reason for a code.
func (c Code) Error() string {
return c.Reason
}
var (
// QosCodes indicicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
2: CodeGrantedQos2,
}
CodeSuccess = Code{Code: 0x00, Reason: "success"}
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"}
CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"}
CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"}
CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"}
CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"}
CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"}
CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"}
ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"}
ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"}
ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"}
ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"}
ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"}
ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"}
ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"}
ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"}
ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"}
ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"}
ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"}
ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"}
ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"}
ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"}
ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"}
ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"}
ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"}
ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"}
ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"}
ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"}
ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"}
ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"}
ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"}
ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"}
ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"}
ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"}
ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"}
ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"}
ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"}
ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"}
ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"}
ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"}
ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"}
ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"}
ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"}
ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"}
ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"}
ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"}
ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"}
ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"}
ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"}
ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"}
ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"}
ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"}
ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"}
ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"}
ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"}
ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"}
ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"}
ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"}
ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"}
ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"}
ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"}
ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"}
ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"}
ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"}
ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"}
ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"}
ErrServerBusy = Code{Code: 0x89, Reason: "server busy"}
ErrBanned = Code{Code: 0x8A, Reason: "banned"}
ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"}
ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"}
ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"}
ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"}
ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"}
ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"}
ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"}
ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"}
ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"}
ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"}
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
)

29
packets/codes_test.go Normal file
View File

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestCodesString(t *testing.T) {
c := Code{
Reason: "test",
Code: 0x1,
}
require.Equal(t, "test", c.String())
}
func TestCodesErrorr(t *testing.T) {
c := Code{
Reason: "error",
Code: 0x1,
}
require.Equal(t, "error", error(c).Error())
}

63
packets/fixedheader.go Normal file
View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
)
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Remaining int `json:"remaining"` // the number of remaining bytes in the payload.
Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte `json:"qos"` // indicates the quality of service expected.
Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time.
Retain bool `json:"retain"` // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, int64(fh.Remaining))
}
// Decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(hb byte) error {
fh.Type = hb >> 4 // Get the message type from the first 4 bytes.
switch fh.Type {
case Publish:
if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 {
return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4]
}
fh.Dup = (hb>>3)&0x01 > 0 // is duplicate
fh.Qos = (hb >> 1) & 0x03 // qos flag
fh.Retain = hb&0x01 > 0 // is retain flag
case Pubrel:
fallthrough
case Subscribe:
fallthrough
case Unsubscribe:
if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1]
return ErrMalformedFlags
}
fh.Qos = (hb >> 1) & 0x03
default:
if (hb>>0)&0x01 != 0 ||
(hb>>1)&0x01 != 0 ||
(hb>>2)&0x01 != 0 ||
(hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1]
return ErrMalformedFlags
}
}
if fh.Qos == 0 && fh.Dup {
return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2]
}
return nil
}

237
packets/fixedheader_test.go Normal file
View File

@@ -0,0 +1,237 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
type fixedHeaderTable struct {
desc string
rawBytes []byte
header FixedHeader
packetError bool
expect error
}
var fixedHeaderExpected = []fixedHeaderTable{
{
desc: "connect",
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "connack",
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "publish",
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "publish qos 1",
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "publish qos 1 retain",
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
desc: "publish qos 2",
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
desc: "publish qos 2 retain",
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
desc: "publish dup qos 0",
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
expect: ErrProtocolViolationDupNoQos,
},
{
desc: "publish dup qos 0 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
expect: ErrProtocolViolationDupNoQos,
},
{
desc: "publish dup qos 1 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
desc: "publish dup qos 2 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
desc: "puback",
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pubrec",
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pubrel",
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "pubcomp",
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "subscribe",
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "suback",
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "unsubscribe",
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "unsuback",
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pingreq",
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pingresp",
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "disconnect",
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "auth",
rawBytes: []byte{Auth << 4, 0x00},
header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
desc: "remaining length 10",
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
desc: "remaining length 512",
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
desc: "remaining length 978",
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
desc: "remaining length 20202",
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
desc: "remaining length oversize",
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
desc: "invalid type dup is true",
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid type qos is 1",
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid type retain is true",
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid publish qos bits 1 + 2 set",
rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00},
header: FixedHeader{Type: Publish},
expect: ErrProtocolViolationQosOutOfRange,
},
{
desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0",
rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00},
header: FixedHeader{Type: Pubrel, Qos: 1},
expect: ErrMalformedFlags,
},
{
desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0",
rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00},
header: FixedHeader{Type: Subscribe, Qos: 1},
expect: ErrMalformedFlags,
},
}
func TestFixedHeaderEncode(t *testing.T) {
for _, wanted := range fixedHeaderExpected {
t.Run(wanted.desc, func(t *testing.T) {
buf := new(bytes.Buffer)
wanted.header.Encode(buf)
if wanted.expect == nil {
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()))
require.EqualValues(t, wanted.rawBytes, buf.Bytes())
}
})
}
}
func TestFixedHeaderDecode(t *testing.T) {
for _, wanted := range fixedHeaderExpected {
t.Run(wanted.desc, func(t *testing.T) {
fh := new(FixedHeader)
err := fh.Decode(wanted.rawBytes[0])
if wanted.expect != nil {
require.Equal(t, wanted.expect, err)
} else {
require.NoError(t, err)
require.Equal(t, wanted.header.Type, fh.Type)
require.Equal(t, wanted.header.Dup, fh.Dup)
require.Equal(t, wanted.header.Qos, fh.Qos)
require.Equal(t, wanted.header.Retain, fh.Retain)
}
})
}
}

1141
packets/packets.go Normal file

File diff suppressed because it is too large Load Diff

502
packets/packets_test.go Normal file
View File

@@ -0,0 +1,502 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"testing"
"github.com/jinzhu/copier"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var packetList = []byte{
Connect,
Connack,
Publish,
Puback,
Pubrec,
Pubrel,
Pubcomp,
Subscribe,
Suback,
Unsubscribe,
Unsuback,
Pingreq,
Pingresp,
Disconnect,
Auth,
}
var pkTable = []TPacketCase{
TPacketData[Connect].Get(TConnectMqtt311),
TPacketData[Connect].Get(TConnectMqtt5),
TPacketData[Connect].Get(TConnectUserPassLWT),
TPacketData[Connack].Get(TConnackAcceptedMqtt5),
TPacketData[Connack].Get(TConnackAcceptedNoSession),
TPacketData[Publish].Get(TPublishBasic),
TPacketData[Publish].Get(TPublishMqtt5),
TPacketData[Puback].Get(TPuback),
TPacketData[Pubrec].Get(TPubrec),
TPacketData[Pubrel].Get(TPubrel),
TPacketData[Pubcomp].Get(TPubcomp),
TPacketData[Subscribe].Get(TSubscribe),
TPacketData[Subscribe].Get(TSubscribeMqtt5),
TPacketData[Suback].Get(TSuback),
TPacketData[Unsubscribe].Get(TUnsubscribe),
TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
TPacketData[Pingreq].Get(TPingreq),
TPacketData[Pingresp].Get(TPingresp),
TPacketData[Disconnect].Get(TDisconnect),
TPacketData[Disconnect].Get(TDisconnectMqtt5),
}
func TestNewPackets(t *testing.T) {
s := NewPackets()
require.NotNil(t, s.internal)
}
func TestPacketsAdd(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{})
require.Contains(t, s.internal, "cl1")
}
func TestPacketsGet(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
pk, ok := s.Get("cl1")
require.True(t, ok)
require.Equal(t, "a1", pk.TopicName)
}
func TestPacketsGetAll(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
s.Add("cl3", Packet{TopicName: "a3"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Contains(t, s.internal, "cl3")
subs := s.GetAll()
require.Len(t, subs, 3)
}
func TestPacketsLen(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Equal(t, 2, s.Len())
}
func TestSPacketsDelete(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
require.Contains(t, s.internal, "cl1")
s.Delete("cl1")
_, ok := s.Get("cl1")
require.False(t, ok)
}
func TestFormatPacketID(t *testing.T) {
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
packet := &Packet{PacketID: id}
require.Equal(t, fmt.Sprint(id), packet.FormatID())
}
}
func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
p := &Subscription{
Qos: 2,
NoLocal: true,
RetainAsPublished: true,
RetainHandling: 2,
}
x := new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
p = &Subscription{
Qos: 1,
NoLocal: false,
RetainAsPublished: false,
RetainHandling: 1,
}
x = new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
}
func TestPacketEncode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !encodeTestOK(wanted) {
return
}
pk := new(Packet)
copier.Copy(pk, wanted.Packet)
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
pk.Mods.AllowResponseInfo = true
buf := new(bytes.Buffer)
var err error
switch pkt {
case Connect:
err = pk.ConnectEncode(buf)
case Connack:
err = pk.ConnackEncode(buf)
case Publish:
err = pk.PublishEncode(buf)
case Puback:
err = pk.PubackEncode(buf)
case Pubrec:
err = pk.PubrecEncode(buf)
case Pubrel:
err = pk.PubrelEncode(buf)
case Pubcomp:
err = pk.PubcompEncode(buf)
case Subscribe:
err = pk.SubscribeEncode(buf)
case Suback:
err = pk.SubackEncode(buf)
case Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case Unsuback:
err = pk.UnsubackEncode(buf)
case Pingreq:
err = pk.PingreqEncode(buf)
case Pingresp:
err = pk.PingrespEncode(buf)
case Disconnect:
err = pk.DisconnectEncode(buf)
case Auth:
err = pk.AuthEncode(buf)
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
encoded := buf.Bytes()
// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
if len(wanted.ActualBytes) > 0 {
wanted.RawBytes = wanted.ActualBytes
}
require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
})
}
}
}
func TestPacketDecode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !decodeTestOK(wanted) {
return
}
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
pk.Mods.AllowResponseInfo = true
pk.FixedHeader.Decode(wanted.RawBytes[0])
if len(wanted.RawBytes) > 0 {
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
}
if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
pk.ProtocolVersion = wanted.Packet.ProtocolVersion
}
buf := wanted.RawBytes[2:]
var err error
switch pkt {
case Connect:
err = pk.ConnectDecode(buf)
case Connack:
err = pk.ConnackDecode(buf)
case Publish:
err = pk.PublishDecode(buf)
case Puback:
err = pk.PubackDecode(buf)
case Pubrec:
err = pk.PubrecDecode(buf)
case Pubrel:
err = pk.PubrelDecode(buf)
case Pubcomp:
err = pk.PubcompDecode(buf)
case Subscribe:
err = pk.SubscribeDecode(buf)
case Suback:
err = pk.SubackDecode(buf)
case Unsubscribe:
err = pk.UnsubscribeDecode(buf)
case Unsuback:
err = pk.UnsubackDecode(buf)
case Pingreq:
err = pk.PingreqDecode(buf)
case Pingresp:
err = pk.PingrespDecode(buf)
case Disconnect:
err = pk.DisconnectDecode(buf)
case Auth:
err = pk.AuthDecode(buf)
}
if wanted.FailFirst != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
if pkt == Connect {
// we use ProtocolVersion for controlling packet encoding, but we don't need to test
// against it unless it's a connect packet.
require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
}
require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
})
}
}
}
func TestValidate(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if wanted.Group == "validate" || wanted.Primary {
pk := wanted.Packet
var err error
switch pkt {
case Connect:
err = pk.ConnectValidate()
case Publish:
err = pk.PublishValidate(1024)
case Subscribe:
err = pk.SubscribeValidate()
case Unsubscribe:
err = pk.UnsubscribeValidate()
case Auth:
err = pk.AuthValidate()
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
}
}
})
}
}
}
func TestAckValidatePubrec(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoMatchingSubscribers.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicNameInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrPayloadFormatInvalid.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubrel(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubcomp(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateSuback(t *testing.T) {
for _, b := range []byte{
CodeGrantedQos0.Code,
CodeGrantedQos1.Code,
CodeGrantedQos2.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrSharedSubscriptionsNotSupported.Code,
ErrSubscriptionIdentifiersNotSupported.Code,
ErrWildcardSubscriptionsNotSupported.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateUnsuback(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoSubscriptionExisted.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestReasonCodeValidMisc(t *testing.T) {
pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
require.True(t, pk.ReasonCodeValid())
}
func TestCopy(t *testing.T) {
for _, tt := range pkTable {
pkc := tt.Packet.Copy(true)
require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
}
}
func TestMergeSubscription(t *testing.T) {
sub := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 0,
RetainAsPublished: false,
NoLocal: false,
Identifier: 1,
}
sub2 := Subscription{
Filter: "a/b/d",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 2,
}
expect := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 1,
Identifiers: map[string]int{
"a/b/c": 1,
"a/b/d": 2,
},
}
require.Equal(t, expect, sub.Merge(sub2))
}

478
packets/properties.go Normal file
View File

@@ -0,0 +1,478 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"strings"
)
const (
PropPayloadFormat byte = 1
PropMessageExpiryInterval byte = 2
PropContentType byte = 3
PropResponseTopic byte = 8
PropCorrelationData byte = 9
PropSubscriptionIdentifier byte = 11
PropSessionExpiryInterval byte = 17
PropAssignedClientID byte = 18
PropServerKeepAlive byte = 19
PropAuthenticationMethod byte = 21
PropAuthenticationData byte = 22
PropRequestProblemInfo byte = 23
PropWillDelayInterval byte = 24
PropRequestResponseInfo byte = 25
PropResponseInfo byte = 26
PropServerReference byte = 28
PropReasonString byte = 31
PropReceiveMaximum byte = 33
PropTopicAliasMaximum byte = 34
PropTopicAlias byte = 35
PropMaximumQos byte = 36
PropRetainAvailable byte = 37
PropUser byte = 38
PropMaximumPacketSize byte = 39
PropWildcardSubAvailable byte = 40
PropSubIDAvailable byte = 41
PropSharedSubAvailable byte = 42
)
// validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1},
PropMessageExpiryInterval: {Publish: 1},
PropContentType: {Publish: 1},
PropResponseTopic: {Publish: 1},
PropCorrelationData: {Publish: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1},
PropServerKeepAlive: {Connack: 1},
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {Connect: 1},
PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1},
PropReasonString: {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropReceiveMaximum: {Connect: 1, Connack: 1},
PropTopicAliasMaximum: {Connect: 1, Connack: 1},
PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1},
PropSharedSubAvailable: {Connack: 1},
}
// UserProperty is an arbitrary key-value pair for a packet user properties array.
type UserProperty struct { // [MQTT-1.5.7-1]
Key string `json:"k"`
Val string `json:"v"`
}
// Properties contains all of the mqtt v5 properties available for a packet.
// Some properties have valid values of 0 or not-present. In this case, we opt for
// property flags to indicate the usage of property.
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
type Properties struct {
CorrelationData []byte `json:"cd"`
SubscriptionIdentifier []int `json:"si"`
AuthenticationData []byte `json:"ad"`
User []UserProperty `json:"user"`
ContentType string `json:"ct"`
ResponseTopic string `json:"rt"`
AssignedClientID string `json:"aci"`
AuthenticationMethod string `json:"am"`
ResponseInfo string `json:"ri"`
ServerReference string `json:"sr"`
ReasonString string `json:"rs"`
MessageExpiryInterval uint32 `json:"me"`
SessionExpiryInterval uint32 `json:"sei"`
WillDelayInterval uint32 `json:"wdi"`
MaximumPacketSize uint32 `json:"mps"`
ServerKeepAlive uint16 `json:"ska"`
ReceiveMaximum uint16 `json:"rm"`
TopicAliasMaximum uint16 `json:"tam"`
TopicAlias uint16 `json:"ta"`
PayloadFormat byte `json:"pf"`
PayloadFormatFlag bool `json:"fpf"`
SessionExpiryIntervalFlag bool `json:"fsei"`
ServerKeepAliveFlag bool `json:"fska"`
RequestProblemInfo byte `json:"rpi"`
RequestProblemInfoFlag bool `json:"frpi"`
RequestResponseInfo byte `json:"rri"`
TopicAliasFlag bool `json:"fta"`
MaximumQos byte `json:"mqos"`
MaximumQosFlag bool `json:"fmqos"`
RetainAvailable byte `json:"ra"`
RetainAvailableFlag bool `json:"fra"`
WildcardSubAvailable byte `json:"wsa"`
WildcardSubAvailableFlag bool `json:"fwsa"`
SubIDAvailable byte `json:"sida"`
SubIDAvailableFlag bool `json:"fsida"`
SharedSubAvailable byte `json:"ssa"`
SharedSubAvailableFlag bool `json:"fssa"`
}
// Copy creates a new Properties struct with copies of the values.
func (p *Properties) Copy(allowTransfer bool) Properties {
pr := Properties{
PayloadFormat: p.PayloadFormat, // [MQTT-3.3.2-4]
PayloadFormatFlag: p.PayloadFormatFlag,
MessageExpiryInterval: p.MessageExpiryInterval,
ContentType: p.ContentType, // [MQTT-3.3.2-20]
ResponseTopic: p.ResponseTopic, // [MQTT-3.3.2-15]
SessionExpiryInterval: p.SessionExpiryInterval,
SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag,
AssignedClientID: p.AssignedClientID,
ServerKeepAlive: p.ServerKeepAlive,
ServerKeepAliveFlag: p.ServerKeepAliveFlag,
AuthenticationMethod: p.AuthenticationMethod,
RequestProblemInfo: p.RequestProblemInfo,
RequestProblemInfoFlag: p.RequestProblemInfoFlag,
WillDelayInterval: p.WillDelayInterval,
RequestResponseInfo: p.RequestResponseInfo,
ResponseInfo: p.ResponseInfo,
ServerReference: p.ServerReference,
ReasonString: p.ReasonString,
ReceiveMaximum: p.ReceiveMaximum,
TopicAliasMaximum: p.TopicAliasMaximum,
TopicAlias: 0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27]
MaximumQos: p.MaximumQos,
MaximumQosFlag: p.MaximumQosFlag,
RetainAvailable: p.RetainAvailable,
RetainAvailableFlag: p.RetainAvailableFlag,
MaximumPacketSize: p.MaximumPacketSize,
WildcardSubAvailable: p.WildcardSubAvailable,
WildcardSubAvailableFlag: p.WildcardSubAvailableFlag,
SubIDAvailable: p.SubIDAvailable,
SubIDAvailableFlag: p.SubIDAvailableFlag,
SharedSubAvailable: p.SharedSubAvailable,
SharedSubAvailableFlag: p.SharedSubAvailableFlag,
}
if allowTransfer {
pr.TopicAlias = p.TopicAlias
pr.TopicAliasFlag = p.TopicAliasFlag
}
if len(p.CorrelationData) > 0 {
pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16]
}
if len(p.SubscriptionIdentifier) > 0 {
pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...)
}
if len(p.AuthenticationData) > 0 {
pr.AuthenticationData = append([]byte{}, p.AuthenticationData...)
}
if len(p.User) > 0 {
pr.User = []UserProperty{}
for _, v := range p.User {
pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17]
Key: v.Key,
Val: v.Val,
})
}
}
return pr
}
// canEncode returns true if the property type is valid for the packet type.
func (p *Properties) canEncode(pkt byte, k byte) bool {
return validPacketProperties[k][pkt] == 1
}
// Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
if p == nil {
return
}
var buf bytes.Buffer
pkt := pk.FixedHeader.Type
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat)
}
if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 {
buf.WriteByte(PropMessageExpiryInterval)
buf.Write(encodeUint32(p.MessageExpiryInterval))
}
if p.canEncode(pkt, PropContentType) && p.ContentType != "" {
buf.WriteByte(PropContentType)
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData))
}
if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 {
for _, v := range p.SubscriptionIdentifier {
if v > 0 {
buf.WriteByte(PropSubscriptionIdentifier)
encodeLength(&buf, int64(v))
}
}
}
if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2]
buf.WriteByte(PropSessionExpiryInterval)
buf.Write(encodeUint32(p.SessionExpiryInterval))
}
if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" {
buf.WriteByte(PropAssignedClientID)
buf.Write(encodeString(p.AssignedClientID))
}
if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag {
buf.WriteByte(PropServerKeepAlive)
buf.Write(encodeUint16(p.ServerKeepAlive))
}
if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" {
buf.WriteByte(PropAuthenticationMethod)
buf.Write(encodeString(p.AuthenticationMethod))
}
if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 {
buf.WriteByte(PropAuthenticationData)
buf.Write(encodeBytes(p.AuthenticationData))
}
if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag {
buf.WriteByte(PropRequestProblemInfo)
buf.WriteByte(p.RequestProblemInfo)
}
if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 {
buf.WriteByte(PropWillDelayInterval)
buf.Write(encodeUint32(p.WillDelayInterval))
}
if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 {
buf.WriteByte(PropRequestResponseInfo)
buf.WriteByte(p.RequestResponseInfo)
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo))
}
if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 {
buf.WriteByte(PropServerReference)
buf.Write(encodeString(p.ServerReference))
}
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString)
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
buf.WriteByte(PropReasonString)
buf.Write(b)
}
}
if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 {
buf.WriteByte(PropReceiveMaximum)
buf.Write(encodeUint16(p.ReceiveMaximum))
}
if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 {
buf.WriteByte(PropTopicAliasMaximum)
buf.Write(encodeUint16(p.TopicAliasMaximum))
}
if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8]
buf.WriteByte(PropTopicAlias)
buf.Write(encodeUint16(p.TopicAlias))
}
if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 {
buf.WriteByte(PropMaximumQos)
buf.WriteByte(p.MaximumQos)
}
if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag {
buf.WriteByte(PropRetainAvailable)
buf.WriteByte(p.RetainAvailable)
}
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := bytes.NewBuffer([]byte{})
for _, v := range p.User {
pb.WriteByte(PropUser)
pb.Write(encodeString(v.Key))
pb.Write(encodeString(v.Val))
}
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
buf.Write(pb.Bytes())
}
}
if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 {
buf.WriteByte(PropMaximumPacketSize)
buf.Write(encodeUint32(p.MaximumPacketSize))
}
if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag {
buf.WriteByte(PropWildcardSubAvailable)
buf.WriteByte(p.WildcardSubAvailable)
}
if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag {
buf.WriteByte(PropSubIDAvailable)
buf.WriteByte(p.SubIDAvailable)
}
if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag {
buf.WriteByte(PropSharedSubAvailable)
buf.WriteByte(p.SharedSubAvailable)
}
encodeLength(b, int64(buf.Len()))
buf.WriteTo(b) // [MQTT-3.1.3-10]
}
// Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
if p == nil {
return 0, nil
}
n, _, err = DecodeLength(b)
if err != nil {
return n, err
}
if n == 0 {
return n, nil
}
bt := b.Bytes()
var k byte
for offset := 0; offset < n; {
k, offset, err = decodeByte(bt, offset)
if err != nil {
return n, err
}
if _, ok := validPacketProperties[k][pk]; !ok {
return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
}
switch k {
case PropPayloadFormat:
p.PayloadFormat, offset, err = decodeByte(bt, offset)
p.PayloadFormatFlag = true
case PropMessageExpiryInterval:
p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset)
case PropContentType:
p.ContentType, offset, err = decodeString(bt, offset)
case PropResponseTopic:
p.ResponseTopic, offset, err = decodeString(bt, offset)
case PropCorrelationData:
p.CorrelationData, offset, err = decodeBytes(bt, offset)
case PropSubscriptionIdentifier:
if p.SubscriptionIdentifier == nil {
p.SubscriptionIdentifier = []int{}
}
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
if err != nil {
return n, err
}
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
offset += bu
case PropSessionExpiryInterval:
p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset)
p.SessionExpiryIntervalFlag = true
case PropAssignedClientID:
p.AssignedClientID, offset, err = decodeString(bt, offset)
case PropServerKeepAlive:
p.ServerKeepAlive, offset, err = decodeUint16(bt, offset)
p.ServerKeepAliveFlag = true
case PropAuthenticationMethod:
p.AuthenticationMethod, offset, err = decodeString(bt, offset)
case PropAuthenticationData:
p.AuthenticationData, offset, err = decodeBytes(bt, offset)
case PropRequestProblemInfo:
p.RequestProblemInfo, offset, err = decodeByte(bt, offset)
p.RequestProblemInfoFlag = true
case PropWillDelayInterval:
p.WillDelayInterval, offset, err = decodeUint32(bt, offset)
case PropRequestResponseInfo:
p.RequestResponseInfo, offset, err = decodeByte(bt, offset)
case PropResponseInfo:
p.ResponseInfo, offset, err = decodeString(bt, offset)
case PropServerReference:
p.ServerReference, offset, err = decodeString(bt, offset)
case PropReasonString:
p.ReasonString, offset, err = decodeString(bt, offset)
case PropReceiveMaximum:
p.ReceiveMaximum, offset, err = decodeUint16(bt, offset)
case PropTopicAliasMaximum:
p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset)
case PropTopicAlias:
p.TopicAlias, offset, err = decodeUint16(bt, offset)
p.TopicAliasFlag = true
case PropMaximumQos:
p.MaximumQos, offset, err = decodeByte(bt, offset)
p.MaximumQosFlag = true
case PropRetainAvailable:
p.RetainAvailable, offset, err = decodeByte(bt, offset)
p.RetainAvailableFlag = true
case PropUser:
var k, v string
k, offset, err = decodeString(bt, offset)
if err != nil {
return n, err
}
v, offset, err = decodeString(bt, offset)
p.User = append(p.User, UserProperty{Key: k, Val: v})
case PropMaximumPacketSize:
p.MaximumPacketSize, offset, err = decodeUint32(bt, offset)
case PropWildcardSubAvailable:
p.WildcardSubAvailable, offset, err = decodeByte(bt, offset)
p.WildcardSubAvailableFlag = true
case PropSubIDAvailable:
p.SubIDAvailable, offset, err = decodeByte(bt, offset)
p.SubIDAvailableFlag = true
case PropSharedSubAvailable:
p.SharedSubAvailable, offset, err = decodeByte(bt, offset)
p.SharedSubAvailableFlag = true
}
if err != nil {
return n, err
}
}
return n, nil
}

333
packets/properties_test.go Normal file
View File

@@ -0,0 +1,333 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var (
propertiesStruct = Properties{
PayloadFormat: byte(1), // UTF-8 Format
PayloadFormatFlag: true,
MessageExpiryInterval: uint32(2),
ContentType: "text/plain",
ResponseTopic: "a/b/c",
CorrelationData: []byte("data"),
SubscriptionIdentifier: []int{322122},
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
AssignedClientID: "mochi-v5",
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
AuthenticationMethod: "SHA-1",
AuthenticationData: []byte("auth-data"),
RequestProblemInfo: byte(1),
RequestProblemInfoFlag: true,
WillDelayInterval: uint32(600),
RequestResponseInfo: byte(1),
ResponseInfo: "response",
ServerReference: "mochi-2",
ReasonString: "reason",
ReceiveMaximum: uint16(500),
TopicAliasMaximum: uint16(999),
TopicAlias: uint16(3),
TopicAliasFlag: true,
MaximumQos: byte(1),
MaximumQosFlag: true,
RetainAvailable: byte(1),
RetainAvailableFlag: true,
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
{
Key: "key2",
Val: "value2",
},
},
MaximumPacketSize: uint32(32000),
WildcardSubAvailable: byte(1),
WildcardSubAvailableFlag: true,
SubIDAvailable: byte(1),
SubIDAvailableFlag: true,
SharedSubAvailable: byte(1),
SharedSubAvailableFlag: true,
}
propertiesBytes = []byte{
172, 1, // VBI
// Payload Format (1) (vbi:2)
1, 1,
// Message Expiry (2) (vbi:7)
2, 0, 0, 0, 2,
// Content Type (3) (vbi:20)
3,
0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n',
// Response Topic (8) (vbi:28)
8,
0, 5, 'a', '/', 'b', '/', 'c',
// Correlations Data (9) (vbi:35)
9,
0, 4, 'd', 'a', 't', 'a',
// Subscription Identifier (11) (vbi:39)
11,
202, 212, 19,
// Session Expiry Interval (17) (vbi:43)
17,
0, 0, 0, 120,
// Assigned Client ID (18) (vbi:55)
18,
0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5',
// Server Keep Alive (19) (vbi:58)
19,
0, 20,
// Authentication Method (21) (vbi:66)
21,
0, 5, 'S', 'H', 'A', '-', '1',
// Authentication Data (22) (vbi:78)
22,
0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a',
// Request Problem Info (23) (vbi:80)
23, 1,
// Will Delay Interval (24) (vbi:85)
24,
0, 0, 2, 88,
// Request Response Info (25) (vbi:87)
25, 1,
// Response Info (26) (vbi:98)
26,
0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e',
// Server Reference (28) (vbi:108)
28,
0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2',
// Reason String (31) (vbi:117)
31,
0, 6, 'r', 'e', 'a', 's', 'o', 'n',
// Receive Maximum (33) (vbi:120)
33,
1, 244,
// Topic Alias Maximum (34) (vbi:123)
34,
3, 231,
// Topic Alias (35) (vbi:126)
35,
0, 3,
// Maximum Qos (36) (vbi:128)
36, 1,
// Retain Available (37) (vbi: 130)
37, 1,
// User Properties (38) (vbi:161)
38,
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
38,
0, 4, 'k', 'e', 'y', '2',
0, 6, 'v', 'a', 'l', 'u', 'e', '2',
// Maximum Packet Size (39) (vbi:166)
39,
0, 0, 125, 0,
// Wildcard Subscriptions Available (40) (vbi:168)
40, 1,
// Subscription ID Available (41) (vbi:170)
41, 1,
// Shared Subscriptions Available (42) (vbi:172)
42, 1,
}
)
func init() {
validPacketProperties[PropPayloadFormat][Reserved] = 1
validPacketProperties[PropMessageExpiryInterval][Reserved] = 1
validPacketProperties[PropContentType][Reserved] = 1
validPacketProperties[PropResponseTopic][Reserved] = 1
validPacketProperties[PropCorrelationData][Reserved] = 1
validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1
validPacketProperties[PropSessionExpiryInterval][Reserved] = 1
validPacketProperties[PropAssignedClientID][Reserved] = 1
validPacketProperties[PropServerKeepAlive][Reserved] = 1
validPacketProperties[PropAuthenticationMethod][Reserved] = 1
validPacketProperties[PropAuthenticationData][Reserved] = 1
validPacketProperties[PropRequestProblemInfo][Reserved] = 1
validPacketProperties[PropWillDelayInterval][Reserved] = 1
validPacketProperties[PropRequestResponseInfo][Reserved] = 1
validPacketProperties[PropResponseInfo][Reserved] = 1
validPacketProperties[PropServerReference][Reserved] = 1
validPacketProperties[PropReasonString][Reserved] = 1
validPacketProperties[PropReceiveMaximum][Reserved] = 1
validPacketProperties[PropTopicAliasMaximum][Reserved] = 1
validPacketProperties[PropTopicAlias][Reserved] = 1
validPacketProperties[PropMaximumQos][Reserved] = 1
validPacketProperties[PropRetainAvailable][Reserved] = 1
validPacketProperties[PropUser][Reserved] = 1
validPacketProperties[PropMaximumPacketSize][Reserved] = 1
validPacketProperties[PropWildcardSubAvailable][Reserved] = 1
validPacketProperties[PropSubIDAvailable][Reserved] = 1
validPacketProperties[PropSharedSubAvailable][Reserved] = 1
}
func TestEncodeProperties(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes())
}
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8}))
}
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
}
func TestEncodePropertiesNil(t *testing.T) {
type tmp struct {
p *Properties
}
pr := tmp{}
b := bytes.NewBuffer([]byte{})
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
require.Equal(t, []byte{}, b.Bytes())
}
func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties)
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes())
}
func TestDecodeProperties(t *testing.T) {
b := bytes.NewBuffer(propertiesBytes)
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172, n)
require.EqualValues(t, propertiesStruct, *props)
}
func TestDecodePropertiesNil(t *testing.T) {
b := bytes.NewBuffer(propertiesBytes)
type tmp struct {
p *Properties
}
pr := tmp{}
n, err := pr.p.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 0, n)
}
func TestDecodePropertiesBadInitialVBI(t *testing.T) {
b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, ErrMalformedVariableByteInteger, err)
}
func TestDecodePropertiesZeroLengthVBI(t *testing.T) {
b := bytes.NewBuffer([]byte{0})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, props, new(Properties))
}
func TestDecodePropertiesBadKeyByte(t *testing.T) {
b := bytes.NewBuffer([]byte{64, 1})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange)
}
func TestDecodePropertiesInvalidForPacket(t *testing.T) {
b := bytes.NewBuffer([]byte{1, 99})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty)
}
func TestDecodePropertiesGeneralFailure(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestDecodePropertiesBadSubscriptionID(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestDecodePropertiesBadUserProps(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestCopyProperties(t *testing.T) {
require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true))
}
func TestCopyPropertiesNoTransfer(t *testing.T) {
pkA := propertiesStruct
pkB := pkA.Copy(false)
// Properties which should never be transferred from one connection to another
require.Equal(t, uint16(0), pkB.TopicAlias)
}

3802
packets/tpackets.go Normal file

File diff suppressed because it is too large Load Diff

33
packets/tpackets_test.go Normal file
View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func encodeTestOK(wanted TPacketCase) bool {
if wanted.RawBytes == nil {
return false
}
if wanted.Group != "" && wanted.Group != "encode" {
return false
}
return true
}
func decodeTestOK(wanted TPacketCase) bool {
if wanted.Group != "" && wanted.Group != "decode" {
return false
}
return true
}
func TestTPacketCaseGet(t *testing.T) {
require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311))
require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128)))
}

1439
server.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,206 +0,0 @@
package circ
import (
"errors"
"io"
"sync"
"sync/atomic"
)
var (
DefaultBufferSize int = 1024 * 256 // the default size of the buffer in bytes.
DefaultBlockSize int = 1024 * 8 // the default size per R/W block in bytes.
ErrOutOfRange = errors.New("Indexes out of range")
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
)
// buffer contains core values and methods to be included in a reader or writer.
type Buffer struct {
Mu sync.RWMutex // the buffer needs it's own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
buf []byte // the bytes buffer.
tmp []byte // a temporary buffer.
head int64 // the current position in the sequence - a forever increasing index.
tail int64 // the committed position in the sequence - a forever increasing index.
rcond *sync.Cond // the sync condition for the buffer reader.
wcond *sync.Cond // the sync condition for the buffer writer.
done int64 // indicates that the buffer is closed.
State int64 // indicates whether the buffer is reading from (1) or writing to (2).
}
// NewBuffer returns a new instance of buffer. You should call NewReader or
// NewWriter instead of this function.
func NewBuffer(size, block int) Buffer {
if size == 0 {
size = DefaultBufferSize
}
if block == 0 {
block = DefaultBlockSize
}
if size < 2*block {
size = 2 * block
}
return Buffer{
size: size,
mask: size - 1,
block: block,
buf: make([]byte, size),
rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)),
}
}
// NewBufferFromSlice returns a new instance of buffer using a
// pre-existing byte slice.
func NewBufferFromSlice(block int, buf []byte) Buffer {
l := len(buf)
if block == 0 {
block = DefaultBlockSize
}
b := Buffer{
size: l,
mask: l - 1,
block: block,
buf: buf,
rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)),
}
return b
}
// Get will return the tail and head positions of the buffer.
// This method is for use with testing.
func (b *Buffer) GetPos() (int64, int64) {
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
}
// SetPos sets the head and tail of the buffer.
func (b *Buffer) SetPos(tail, head int64) {
atomic.StoreInt64(&b.tail, tail)
atomic.StoreInt64(&b.head, head)
}
// Get returns the internal buffer.
func (b *Buffer) Get() []byte {
b.Mu.Lock()
defer b.Mu.Unlock()
return b.buf
}
// Set writes bytes to a range of indexes in the byte buffer.
func (b *Buffer) Set(p []byte, start, end int) error {
b.Mu.Lock()
defer b.Mu.Unlock()
if end > b.size || start > b.size {
return ErrOutOfRange
}
o := 0
for i := start; i < end; i++ {
b.buf[i] = p[o]
o++
}
return nil
}
// Index returns the buffer-relative index of an integer.
func (b *Buffer) Index(i int64) int {
return b.mask & int(i)
}
// awaitEmpty will block until there is at least n bytes between
// the head and the tail (looking forward).
func (b *Buffer) awaitEmpty(n int) error {
// If the head has wrapped behind the tail, and next will overrun tail,
// then wait until tail has moved.
b.rcond.L.Lock()
for !b.checkEmpty(n) {
if atomic.LoadInt64(&b.done) == 1 {
b.rcond.L.Unlock()
return io.EOF
}
b.rcond.Wait()
}
b.rcond.L.Unlock()
return nil
}
// awaitFilled will block until there are at least n bytes between the
// tail and the head (looking forward).
func (b *Buffer) awaitFilled(n int) error {
// Because awaitCapacity prevents the head from overrunning the t
// able on write, we can simply ensure there is enough space
// the forever-incrementing tail and head integers.
b.wcond.L.Lock()
for !b.checkFilled(n) {
if atomic.LoadInt64(&b.done) == 1 {
b.wcond.L.Unlock()
return io.EOF
}
b.wcond.Wait()
}
b.wcond.L.Unlock()
return nil
}
// checkEmpty returns true if there are at least n bytes between the head and
// the tail.
func (b *Buffer) checkEmpty(n int) bool {
head := atomic.LoadInt64(&b.head)
next := head + int64(n)
tail := atomic.LoadInt64(&b.tail)
if next-tail > int64(b.size) {
return false
}
return true
}
// checkFilled returns true if there are at least n bytes between the tail and
// the head.
func (b *Buffer) checkFilled(n int) bool {
if atomic.LoadInt64(&b.tail)+int64(n) <= atomic.LoadInt64(&b.head) {
return true
}
return false
}
// CommitTail moves the tail position of the buffer n bytes.
func (b *Buffer) CommitTail(n int) {
atomic.AddInt64(&b.tail, int64(n))
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
}
// CapDelta returns the difference between the head and tail.
func (b *Buffer) CapDelta() int {
return int(atomic.LoadInt64(&b.head) - atomic.LoadInt64(&b.tail))
}
// Stop signals the buffer to stop processing.
func (b *Buffer) Stop() {
atomic.StoreInt64(&b.done, 1)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
}

View File

@@ -1,304 +0,0 @@
package circ
import (
//"fmt"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewBuffer(t *testing.T) {
var size int = 16
var block int = 4
buf := NewBuffer(size, block)
require.NotNil(t, buf.buf)
require.NotNil(t, buf.rcond)
require.NotNil(t, buf.wcond)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewBuffer0Size(t *testing.T) {
buf := NewBuffer(0, 0)
require.NotNil(t, buf.buf)
require.Equal(t, DefaultBufferSize, buf.size)
require.Equal(t, DefaultBlockSize, buf.block)
}
func TestNewBufferUndersize(t *testing.T) {
buf := NewBuffer(DefaultBlockSize+10, DefaultBlockSize)
require.NotNil(t, buf.buf)
require.Equal(t, DefaultBlockSize*2, buf.size)
require.Equal(t, DefaultBlockSize, buf.block)
}
func TestNewBufferFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewBufferFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestNewBufferFromSlice0Size(t *testing.T) {
b := NewBytesPool(256)
buf := NewBufferFromSlice(0, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestGetPos(t *testing.T) {
buf := NewBuffer(16, 4)
tail, head := buf.GetPos()
require.Equal(t, int64(0), tail)
require.Equal(t, int64(0), head)
buf.tail = 3
buf.head = 11
tail, head = buf.GetPos()
require.Equal(t, int64(3), tail)
require.Equal(t, int64(11), head)
}
func TestGet(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, make([]byte, 16), buf.Get())
buf.buf[0] = 1
buf.buf[15] = 1
require.Equal(t, []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, buf.Get())
}
func TestSetPos(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, int64(0), buf.tail)
require.Equal(t, int64(0), buf.head)
buf.SetPos(4, 8)
require.Equal(t, int64(4), buf.tail)
require.Equal(t, int64(8), buf.head)
}
func TestSet(t *testing.T) {
buf := NewBuffer(16, 4)
err := buf.Set([]byte{1, 1, 1, 1}, 17, 19)
require.Error(t, err)
err = buf.Set([]byte{1, 1, 1, 1}, 4, 8)
require.NoError(t, err)
require.Equal(t, []byte{0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, buf.buf)
}
func TestIndex(t *testing.T) {
buf := NewBuffer(1024, 4)
require.Equal(t, 512, buf.Index(512))
require.Equal(t, 0, buf.Index(1024))
require.Equal(t, 6, buf.Index(1030))
require.Equal(t, 6, buf.Index(61446))
}
func TestAwaitFilled(t *testing.T) {
tests := []struct {
tail int64
head int64
n int
await int
desc string
}{
{tail: 0, head: 4, n: 4, await: 1, desc: "OK 0, 4"},
{tail: 8, head: 11, n: 4, await: 1, desc: "OK 8, 11"},
{tail: 102, head: 103, n: 4, await: 3, desc: "OK 102, 103"},
}
for i, tt := range tests {
//fmt.Println(i)
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan error)
go func() {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.AddInt64(&buf.head, int64(tt.await))
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestAwaitFilledEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}
func TestAwaitEmptyOK(t *testing.T) {
tests := []struct {
tail int64
head int64
await int
desc string
}{
{tail: 0, head: 0, await: 0, desc: "OK 0, 0"},
{tail: 0, head: 5, await: 0, desc: "OK 0, 5"},
{tail: 0, head: 14, await: 3, desc: "OK wrap 0, 14 "},
{tail: 22, head: 35, await: 2, desc: "OK wrap 0, 14 "},
{tail: 15, head: 17, await: 7, desc: "OK 15,2"},
{tail: 0, head: 10, await: 2, desc: "OK 0, 10"},
{tail: 1, head: 15, await: 4, desc: "OK 2, 14"},
}
for i, tt := range tests {
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan error)
go func() {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.AddInt64(&buf.tail, int64(tt.await))
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestAwaitEmptyEnded(t *testing.T) {
buf := NewBuffer(16, 4)
buf.SetPos(1, 15)
o := make(chan error)
go func() {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
require.Error(t, <-o)
}
func TestCheckEmpty(t *testing.T) {
buf := NewBuffer(16, 4)
tests := []struct {
head int64
tail int64
want bool
desc string
}{
{tail: 0, head: 0, want: true, desc: "0, 0 true"},
{tail: 3, head: 4, want: true, desc: "4, 3 true"},
{tail: 15, head: 17, want: true, desc: "15, 17(1) true"},
{tail: 1, head: 30, want: false, desc: "1, 30(14) false"},
{tail: 15, head: 30, want: false, desc: "15, 30(14) false; head has caught up to tail"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
require.Equal(t, tt.want, buf.checkEmpty(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc)
}
}
func TestCheckFilled(t *testing.T) {
buf := NewBuffer(16, 4)
tests := []struct {
head int64
tail int64
want bool
desc string
}{
{tail: 0, head: 0, want: false, desc: "0, 0 false"},
{tail: 0, head: 4, want: true, desc: "0, 4 true"},
{tail: 14, head: 16, want: false, desc: "14,16 false"},
{tail: 14, head: 18, want: true, desc: "14,16 true"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
require.Equal(t, tt.want, buf.checkFilled(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc)
}
}
func TestCommitTail(t *testing.T) {
tests := []struct {
tail int64
head int64
n int
next int64
await int
desc string
}{
{tail: 0, head: 5, n: 4, next: 4, await: 0, desc: "OK 0, 4"},
{tail: 0, head: 5, n: 6, next: 6, await: 1, desc: "OK 0, 5"},
}
for i, tt := range tests {
buf := NewBuffer(16, 4)
buf.SetPos(tt.tail, tt.head)
go func() {
buf.CommitTail(tt.n)
}()
time.Sleep(time.Millisecond)
for j := 0; j < tt.await; j++ {
atomic.AddInt64(&buf.head, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
}
require.Equal(t, tt.next, buf.tail, "Next tail mismatch [i:%d] %s", i, tt.desc)
}
}
/*
func TestCommitTailEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
o <- buf.CommitTail(5)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}
*/
func TestCapDelta(t *testing.T) {
buf := NewBuffer(16, 4)
require.Equal(t, 0, buf.CapDelta())
buf.SetPos(10, 15)
require.Equal(t, 5, buf.CapDelta())
}
func TestStop(t *testing.T) {
buf := NewBuffer(16, 4)
buf.Stop()
require.Equal(t, int64(1), buf.done)
}

View File

@@ -1,34 +0,0 @@
package circ
import (
"sync"
)
// BytesPool is a pool of []byte.
type BytesPool struct {
pool sync.Pool
}
// NewBytesPool returns a sync.pool of []byte.
func NewBytesPool(n int) BytesPool {
return BytesPool{
pool: sync.Pool{
New: func() interface{} {
return make([]byte, n)
},
},
}
}
// Get returns a pooled bytes.Buffer.
func (b *BytesPool) Get() []byte {
return b.pool.Get().([]byte)
}
// Put puts the byte slice back into the pool.
func (b *BytesPool) Put(x []byte) {
for i := range x {
x[i] = 0
}
b.pool.Put(x)
}

View File

@@ -1,46 +0,0 @@
package circ
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewBytesPool(t *testing.T) {
bpool := NewBytesPool(256)
require.NotNil(t, bpool.pool)
}
func BenchmarkNewBytesPool(b *testing.B) {
for n := 0; n < b.N; n++ {
NewBytesPool(256)
}
}
func TestNewBytesPoolGet(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
require.Equal(t, make([]byte, 256), buf)
}
func BenchmarkBytesPoolGet(b *testing.B) {
bpool := NewBytesPool(256)
for n := 0; n < b.N; n++ {
bpool.Get()
}
}
func TestNewBytesPoolPut(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
bpool.Put(buf)
}
func BenchmarkBytesPoolPut(b *testing.B) {
bpool := NewBytesPool(256)
buf := bpool.Get()
for n := 0; n < b.N; n++ {
bpool.Put(buf)
}
}

View File

@@ -1,96 +0,0 @@
package circ
import (
"io"
"sync/atomic"
)
// Reader is a circular buffer for reading data from an io.Reader.
type Reader struct {
Buffer
}
// NewReader returns a new Circular Reader.
func NewReader(size, block int) *Reader {
b := NewBuffer(size, block)
b.ID = "\treader"
return &Reader{
b,
}
}
// NewReaderFromSlice returns a new Circular Reader using a pre-exising
// byte slice.
func NewReaderFromSlice(block int, p []byte) *Reader {
b := NewBufferFromSlice(block, p)
b.ID = "\treader"
return &Reader{
b,
}
}
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
atomic.StoreInt64(&b.State, 1)
defer atomic.StoreInt64(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 {
return total, nil
}
// Wait until there's enough capacity in the buffer before
// trying to read more bytes from the io.Reader.
err := b.awaitEmpty(b.block)
if err != nil {
// b.done is the only error condition for awaitCapacity
// so loop around and return properly.
continue
}
// If the block will overrun the circle end, just fill up
// and collect the rest on the next pass.
start := b.Index(atomic.LoadInt64(&b.head))
end := start + b.block
if end > b.size {
end = b.size
}
// Read into the buffer between the start and end indexes only.
n, err := r.Read(b.buf[start:end])
total += int64(n) // incr total bytes read.
if err != nil {
return total, nil
}
// Move the head forward however many bytes were read.
atomic.AddInt64(&b.head, int64(n))
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
}
}
// Read reads n bytes from the buffer, and will block until at n bytes
// exist in the buffer to read.
func (b *Buffer) Read(n int) (p []byte, err error) {
err = b.awaitFilled(n)
if err != nil {
return
}
tail := atomic.LoadInt64(&b.tail)
next := tail + int64(n)
// If the read overruns the buffer, get everything until the end
// and then whatever is left from the start.
if b.Index(tail) > b.Index(next) {
b.tmp = b.buf[b.Index(tail):]
b.tmp = append(b.tmp, b.buf[:b.Index(next)]...)
} else {
b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read.
}
return b.tmp, nil
}

View File

@@ -1,125 +0,0 @@
package circ
import (
"bytes"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewReader(t *testing.T) {
var size = 16
var block = 4
buf := NewReader(size, block)
require.NotNil(t, buf.buf)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewReaderFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewReaderFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestReadFrom(t *testing.T) {
buf := NewReader(16, 4)
b4 := bytes.Repeat([]byte{'-'}, 4)
br := bytes.NewReader(b4)
_, err := buf.ReadFrom(br)
require.NoError(t, err)
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buf[:4])
require.Equal(t, int64(4), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.Equal(t, int64(8), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.Equal(t, int64(12), buf.head)
}
func TestReadFromWrap(t *testing.T) {
buf := NewReader(16, 4)
buf.buf = bytes.Repeat([]byte{'-'}, 16)
buf.SetPos(8, 14)
br := bytes.NewReader(bytes.Repeat([]byte{'/'}, 8))
o := make(chan error)
go func() {
_, err := buf.ReadFrom(br)
o <- err
}()
time.Sleep(time.Millisecond * 100)
go func() {
atomic.StoreInt64(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
}()
<-o
require.Equal(t, []byte{'/', '/', '/', '/', '/', '/', '-', '-', '-', '-', '-', '-', '-', '-', '/', '/'}, buf.Get())
require.Equal(t, int64(22), atomic.LoadInt64(&buf.head))
require.Equal(t, 6, buf.Index(atomic.LoadInt64(&buf.head)))
}
func TestReadOK(t *testing.T) {
buf := NewReader(16, 4)
buf.buf = []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
tests := []struct {
tail int64
head int64
n int
bytes []byte
desc string
}{
{tail: 0, head: 4, n: 4, bytes: []byte{'a', 'b', 'c', 'd'}, desc: "0, 4 OK"},
{tail: 3, head: 15, n: 8, bytes: []byte{'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k'}, desc: "3, 15 OK"},
{tail: 14, head: 15, n: 6, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd'}, desc: "14, 2 wrapped OK"},
}
for i, tt := range tests {
buf.SetPos(tt.tail, tt.head)
o := make(chan []byte)
go func() {
p, _ := buf.Read(tt.n)
o <- p
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.head, buf.head+int64(tt.n))
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
done := <-o
require.Equal(t, tt.bytes, done, "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
}
}
func TestReadEnded(t *testing.T) {
buf := NewBuffer(16, 4)
o := make(chan error)
go func() {
_, err := buf.Read(4)
o <- err
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
require.Error(t, <-o)
}

View File

@@ -1,107 +0,0 @@
package circ
import (
"fmt"
"io"
"sync/atomic"
)
// Writer is a circular buffer for writing data to an io.Writer.
type Writer struct {
Buffer
}
// NewWriter returns a pointer to a new Circular Writer.
func NewWriter(size, block int) *Writer {
b := NewBuffer(size, block)
b.ID = "writer"
return &Writer{
b,
}
}
// NewWriterFromSlice returns a new Circular Writer using a pre-exising
// byte slice.
func NewWriterFromSlice(block int, p []byte) *Writer {
b := NewBufferFromSlice(block, p)
b.ID = "writer"
return &Writer{
b,
}
}
// WriteTo writes the contents of the buffer to an io.Writer.
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
atomic.StoreInt64(&b.State, 2)
defer atomic.StoreInt64(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 {
return total, io.EOF
}
// Read from the buffer until there is at least 1 byte to write.
err = b.awaitFilled(1)
if err != nil {
return
}
// Get all the bytes between the tail and head, wrapping if necessary.
tail := atomic.LoadInt64(&b.tail)
rTail := b.Index(tail)
rHead := b.Index(atomic.LoadInt64(&b.head))
n := b.CapDelta()
p := make([]byte, 0, n)
if rTail > rHead {
p = append(p, b.buf[rTail:]...)
p = append(p, b.buf[:rHead]...)
} else {
p = append(p, b.buf[rTail:rHead]...)
}
//fmt.Println("writing", p)
n, err = w.Write(p)
total += n
if err != nil {
fmt.Println("writing err", err)
return
}
//fmt.Println("written", n)
// Move the tail forward the bytes written and broadcast change.
atomic.StoreInt64(&b.tail, tail+int64(n))
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
}
}
// Write writes the buffer to the buffer p, returning the number of bytes written.
func (b *Writer) Write(p []byte) (total int, err error) {
err = b.awaitEmpty(len(p))
if err != nil {
return
}
total = b.writeBytes(p)
atomic.AddInt64(&b.head, int64(total))
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
return
}
// writeBytes writes bytes to the buffer from the start position, and returns
// the new head position. This function does not wait for capacity and will
// overwrite any existing bytes.
func (b *Writer) writeBytes(p []byte) int {
var o int
var n int
for i := 0; i < len(p); i++ {
o = b.Index(atomic.LoadInt64(&b.head) + int64(i))
b.buf[o] = p[i]
n++
}
return n
}

View File

@@ -1,155 +0,0 @@
package circ
import (
"bufio"
"bytes"
"net"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewWriter(t *testing.T) {
var size = 16
var block = 4
buf := NewWriter(size, block)
require.NotNil(t, buf.buf)
require.Equal(t, size, len(buf.buf))
require.Equal(t, size, buf.size)
require.Equal(t, block, buf.block)
}
func TestNewWriterFromSlice(t *testing.T) {
b := NewBytesPool(256)
buf := NewWriterFromSlice(DefaultBlockSize, b.Get())
require.NotNil(t, buf.buf)
require.Equal(t, 256, cap(buf.buf))
}
func TestWriteTo(t *testing.T) {
tests := []struct {
tail int64
head int64
bytes []byte
await int
total int
err error
desc string
}{
{tail: 0, head: 5, bytes: []byte{'a', 'b', 'c', 'd', 'e'}, desc: "0,5 OK"},
{tail: 14, head: 21, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd', 'e'}, desc: "14,16(2) OK"},
}
for i, tt := range tests {
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
buf := NewWriter(16, 4)
buf.Set(bb, 0, 16)
buf.SetPos(tt.tail, tt.head)
var b bytes.Buffer
w := bufio.NewWriter(&b)
nc := make(chan int)
go func() {
n, _ := buf.WriteTo(w)
nc <- n
}()
time.Sleep(time.Millisecond * 100)
atomic.StoreInt64(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
w.Flush()
require.Equal(t, tt.bytes, b.Bytes(), "Written bytes mismatch [i:%d] %s", i, tt.desc)
}
}
func TestWriteToEndedFirst(t *testing.T) {
buf := NewWriter(16, 4)
buf.done = 1
var b bytes.Buffer
w := bufio.NewWriter(&b)
_, err := buf.WriteTo(w)
require.Error(t, err)
}
func TestWriteToBadWriter(t *testing.T) {
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
buf := NewWriter(16, 4)
buf.Set(bb, 0, 16)
buf.SetPos(0, 6)
r, w := net.Pipe()
w.Close()
_, err := buf.WriteTo(w)
require.Error(t, err)
r.Close()
}
func TestWrite(t *testing.T) {
tests := []struct {
tail int64
head int64
rHead int64
bytes []byte
want []byte
desc string
}{
{tail: 0, head: 0, rHead: 4, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, desc: "0>4 OK"},
{tail: 4, head: 14, rHead: 2, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 'a', 'b'}, desc: "14>2 OK"},
}
for i, tt := range tests {
buf := NewWriter(16, 4)
buf.SetPos(tt.tail, tt.head)
o := make(chan []interface{})
go func() {
nn, err := buf.Write(tt.bytes)
o <- []interface{}{nn, err}
}()
done := <-o
require.Equal(t, tt.want, buf.buf, "Wanted written mismatch [i:%d] %s", i, tt.desc)
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
}
}
func TestWriteEnded(t *testing.T) {
buf := NewWriter(16, 4)
buf.SetPos(15, 30)
buf.done = 1
_, err := buf.Write([]byte{'a', 'b', 'c', 'd'})
require.Error(t, err)
}
func TestWriteBytes(t *testing.T) {
tests := []struct {
tail int64
head int64
bytes []byte
want []byte
start int
desc string
}{
{tail: 0, head: 0, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0}, desc: "0,4 OK"},
{tail: 6, head: 6, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 'a', 'b'}, desc: "6,2 OK wrapped"},
}
for i, tt := range tests {
buf := NewWriter(8, 4)
buf.SetPos(tt.tail, tt.head)
n := buf.writeBytes(tt.bytes)
require.Equal(t, tt.want, buf.buf, "Buffer mistmatch [i:%d] %s", i, tt.desc)
require.Equal(t, len(tt.bytes), n)
}
}

View File

@@ -1,514 +0,0 @@
package clients
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
"github.com/mochi-co/mqtt/server/internal/circ"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
var (
// defaultKeepalive is the default connection keepalive value in seconds.
defaultKeepalive uint16 = 10
ErrConnectionClosed = errors.New("Connection not open")
)
// Clients contains a map of the clients known by the broker.
type Clients struct {
sync.RWMutex
internal map[string]*Client // clients known by the broker, keyed on client id.
}
// New returns an instance of Clients.
func New() *Clients {
return &Clients{
internal: make(map[string]*Client),
}
}
// Add adds a new client to the clients map, keyed on client id.
func (cl *Clients) Add(val *Client) {
cl.Lock()
cl.internal[val.ID] = val
cl.Unlock()
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
val, ok := cl.internal[id]
cl.RUnlock()
return val, ok
}
// Len returns the length of the clients map.
func (cl *Clients) Len() int {
cl.RLock()
val := len(cl.internal)
cl.RUnlock()
return val
}
// Delete removes a client from the internal map.
func (cl *Clients) Delete(id string) {
cl.Lock()
delete(cl.internal, id)
cl.Unlock()
}
// GetByListener returns clients matching a listener id.
func (cl *Clients) GetByListener(id string) []*Client {
clients := make([]*Client, 0, cl.Len())
cl.RLock()
for _, v := range cl.internal {
if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 {
clients = append(clients, v)
}
}
cl.RUnlock()
return clients
}
// Client contains information about a client known by the broker.
type Client struct {
sync.RWMutex
conn net.Conn // the net.Conn used to establish the connection.
r *circ.Reader // a reader for reading incoming bytes.
w *circ.Writer // a writer for writing outgoing bytes.
ID string // the client id.
AC auth.Controller // an auth controller inherited from the listener.
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
Listener string // the id of the listener the client is connected to.
Inflight Inflight // a map of in-flight qos messages.
Username []byte // the username the client authenticated with.
keepalive uint16 // the number of seconds the connection can wait.
cleanSession bool // indicates if the client expects a clean-session.
packetID uint32 // the current highest packetID.
LWT LWT // the last will and testament for the client.
State State // the operational state of the client.
system *system.Info // pointers to server system info.
}
// State tracks the state of the client.
type State struct {
Done int64 // atomic counter which indicates that the client has closed.
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
endOnce sync.Once // only end once.
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{
conn: c,
r: r,
w: w,
system: s,
keepalive: defaultKeepalive,
Inflight: Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
State: State{
started: new(sync.WaitGroup),
endedW: new(sync.WaitGroup),
endedR: new(sync.WaitGroup),
},
}
cl.refreshDeadline(cl.keepalive)
return cl
}
// NewClientStub returns an instance of Client with basic initializations. This
// method is typically called by the persistence restoration system.
func NewClientStub(s *system.Info) *Client {
return &Client{
Inflight: Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
State: State{
Done: 1,
},
}
}
// Identify sets the identification values of a client instance.
func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
cl.Listener = lid
cl.AC = ac
cl.ID = pk.ClientIdentifier
if cl.ID == "" {
cl.ID = xid.New().String()
}
cl.r.ID = cl.ID + " READER"
cl.w.ID = cl.ID + " WRITER"
cl.Username = pk.Username
cl.cleanSession = pk.CleanSession
cl.keepalive = pk.Keepalive
if pk.WillFlag {
cl.LWT = LWT{
Topic: pk.WillTopic,
Message: pk.WillMessage,
Qos: pk.WillQos,
Retain: pk.WillRetain,
}
}
cl.refreshDeadline(cl.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.conn != nil {
var expiry time.Time // Nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second)
}
cl.conn.SetDeadline(expiry)
}
}
// NextPacketID returns the next packet id for a client, looping back to 0
// if the maximum ID has been reached.
func (cl *Client) NextPacketID() uint32 {
i := atomic.LoadUint32(&cl.packetID)
if i == uint32(65535) || i == uint32(0) {
atomic.StoreUint32(&cl.packetID, 1)
return 1
}
return atomic.AddUint32(&cl.packetID, 1)
}
// NoteSubscription makes a note of a subscription for the client.
func (cl *Client) NoteSubscription(filter string, qos byte) {
cl.Lock()
cl.Subscriptions[filter] = qos
cl.Unlock()
}
// ForgetSubscription forgests a subscription note for the client.
func (cl *Client) ForgetSubscription(filter string) {
cl.Lock()
delete(cl.Subscriptions, filter)
cl.Unlock()
}
// Start begins the client goroutines reading and writing packets.
func (cl *Client) Start() {
cl.State.started.Add(2)
go func() {
cl.State.started.Done()
cl.w.WriteTo(cl.conn)
cl.State.endedW.Done()
cl.Stop()
}()
cl.State.endedW.Add(1)
go func() {
cl.State.started.Done()
cl.r.ReadFrom(cl.conn)
cl.State.endedR.Done()
cl.Stop()
}()
cl.State.endedR.Add(1)
cl.State.started.Wait()
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop() {
if atomic.LoadInt64(&cl.State.Done) == 1 {
return
}
cl.State.endOnce.Do(func() {
cl.r.Stop()
cl.w.Stop()
cl.State.endedW.Wait()
cl.conn.Close()
cl.State.endedR.Wait()
atomic.StoreInt64(&cl.State.Done, 1)
})
}
// readFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
p, err := cl.r.Read(1)
if err != nil {
return err
}
err = fh.Decode(p[0])
if err != nil {
return err
}
// The remaining length value can be up to 5 bytes. Read through each byte
// looking for continue values, and if found increase the read. Otherwise
// decode the bytes that were legit.
buf := make([]byte, 0, 6)
i := 1
n := 2
for ; n < 6; n++ {
p, err = cl.r.Read(n)
if err != nil {
return err
}
buf = append(buf, p[i])
// If it's not a continuation flag, end here.
if p[i] < 128 {
break
}
// If i has reached 4 without a length terminator, return a protocol violation.
i++
if i == 4 {
return packets.ErrOversizedLengthIndicator
}
}
// Calculate and store the remaining length of the packet payload.
rem, _ := binary.Uvarint(buf)
fh.Remaining = int(rem)
// Having successfully read n bytes, commit the tail forward.
cl.r.CommitTail(n)
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
return nil
}
// Read reads new packets from a client connection
func (cl *Client) Read(h func(*Client, packets.Packet) error) error {
for {
if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
return nil
}
cl.refreshDeadline(cl.keepalive)
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
if err != nil {
return err
}
pk, err := cl.ReadPacket(fh)
if err != nil {
return err
}
err = h(cl, pk) // Process inbound packet.
if err != nil {
return err
}
}
}
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.system.MessagesRecv, 1)
pk.FixedHeader = *fh
if pk.FixedHeader.Remaining == 0 {
return
}
p, err := cl.r.Read(pk.FixedHeader.Remaining)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
px := append([]byte{}, p[:]...)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectDecode(px)
case packets.Connack:
err = pk.ConnackDecode(px)
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.system.PublishRecv, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
case packets.Pubrec:
err = pk.PubrecDecode(px)
case packets.Pubrel:
err = pk.PubrelDecode(px)
case packets.Pubcomp:
err = pk.PubcompDecode(px)
case packets.Subscribe:
err = pk.SubscribeDecode(px)
case packets.Suback:
err = pk.SubackDecode(px)
case packets.Unsubscribe:
err = pk.UnsubscribeDecode(px)
case packets.Unsuback:
err = pk.UnsubackDecode(px)
case packets.Pingreq:
case packets.Pingresp:
case packets.Disconnect:
default:
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
}
cl.r.CommitTail(pk.FixedHeader.Remaining)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if atomic.LoadInt64(&cl.State.Done) == 1 {
return 0, ErrConnectionClosed
}
cl.w.Mu.Lock()
defer cl.w.Mu.Unlock()
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectEncode(buf)
case packets.Connack:
err = pk.ConnackEncode(buf)
case packets.Publish:
err = pk.PublishEncode(buf)
if err == nil {
atomic.AddInt64(&cl.system.PublishSent, 1)
}
case packets.Puback:
err = pk.PubackEncode(buf)
case packets.Pubrec:
err = pk.PubrecEncode(buf)
case packets.Pubrel:
err = pk.PubrelEncode(buf)
case packets.Pubcomp:
err = pk.PubcompEncode(buf)
case packets.Subscribe:
err = pk.SubscribeEncode(buf)
case packets.Suback:
err = pk.SubackEncode(buf)
case packets.Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case packets.Unsuback:
err = pk.UnsubackEncode(buf)
case packets.Pingreq:
err = pk.PingreqEncode(buf)
case packets.Pingresp:
err = pk.PingrespEncode(buf)
case packets.Disconnect:
err = pk.DisconnectEncode(buf)
default:
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
}
if err != nil {
return
}
n, err = cl.w.Write(buf.Bytes())
if err != nil {
return
}
atomic.AddInt64(&cl.system.BytesSent, int64(n))
atomic.AddInt64(&cl.system.MessagesSent, 1)
cl.refreshDeadline(cl.keepalive)
return
}
// LWT contains the last will and testament details for a client connection.
type LWT struct {
Topic string // the topic the will message shall be sent to.
Message []byte // the message that shall be sent when the client disconnects.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}
// InflightMessage contains data about a packet which is currently in-flight.
type InflightMessage struct {
Packet packets.Packet // the packet currently in-flight.
Sent int64 // the last time the message was sent (for retries) in unixtime.
Resends int // the number of times the message was attempted to be sent.
}
// Inflight is a map of InflightMessage keyed on packet id.
type Inflight struct {
sync.RWMutex
internal map[uint16]InflightMessage // internal contains the inflight messages.
}
// Set stores the packet of an Inflight message, keyed on message id. Returns
// true if the inflight message was new.
func (i *Inflight) Set(key uint16, in InflightMessage) bool {
i.Lock()
_, ok := i.internal[key]
i.internal[key] = in
i.Unlock()
return !ok
}
// Get returns the value of an in-flight message if it exists.
func (i *Inflight) Get(key uint16) (InflightMessage, bool) {
i.RLock()
val, ok := i.internal[key]
i.RUnlock()
return val, ok
}
// Len returns the size of the in-flight messages map.
func (i *Inflight) Len() int {
i.RLock()
v := len(i.internal)
i.RUnlock()
return v
}
// GetAll returns all the in-flight messages.
func (i *Inflight) GetAll() map[uint16]InflightMessage {
i.RLock()
defer i.RUnlock()
return i.internal
}
// Delete removes an in-flight message from the map. Returns true if the
// message existed.
func (i *Inflight) Delete(key uint16) bool {
i.Lock()
_, ok := i.internal[key]
delete(i.internal, key)
i.Unlock()
return ok
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,383 +0,0 @@
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestBytesToString(t *testing.T) {
b := []byte{'a', 'b', 'c'}
require.Equal(t, "abc", bytesToString(b))
}
func BenchmarkBytesToString(b *testing.B) {
for n := 0; n < b.N; n++ {
bytesToString([]byte{'a', 'b', 'c'})
}
}
func TestDecodeString(t *testing.T) {
expect := []struct {
rawBytes []byte
result []string
offset int
shouldFail bool
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: []string{"a/b/c/d", "a"},
},
{
offset: 14,
rawBytes: []byte{
byte(Connect << 4), 17, // Fixed header
0, 6, // Protocol Name - MSB+LSB
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
3, // Protocol Version
0, // Packet Flags
0, 30, // Keepalive
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "zen"},
},
result: []string{"hey"},
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: []string{"1/2/3/4/a/b/c/d/e/^/@/!", "a"},
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: []string{"x/y/z", "!@#$%^&"},
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
result: []string{"a/b/c/d", "z"},
shouldFail: true,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
result: []string{"a/b/c/d", "x"},
shouldFail: true,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
result: []string{"a/b/c/d", "y"},
shouldFail: true,
},
{
offset: 17,
rawBytes: []byte{
byte(Connect << 4), 0, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 6, // Will Topic - MSB+LSB
'l',
},
result: []string{"lwt"},
shouldFail: true,
},
}
for i, wanted := range expect {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding string [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding string [i:%d]", i)
require.Equal(t, wanted.result[0], result, "Incorrect decoded value [i:%d]", i)
}
}
func BenchmarkDecodeString(b *testing.B) {
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
for n := 0; n < b.N; n++ {
decodeString(in, 0)
}
}
func TestDecodeBytes(t *testing.T) {
expect := []struct {
rawBytes []byte
result []uint8
next int
offset int
shouldFail bool
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session)
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // ... truncated connect packet, only checking start
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 0,
shouldFail: true,
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 8,
shouldFail: true,
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 0,
shouldFail: true,
},
}
for i, wanted := range expect {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding bytes [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding bytes [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
}
}
func BenchmarkDecodeBytes(b *testing.B) {
in := []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}
for n := 0; n < b.N; n++ {
decodeBytes(in, 0)
}
}
func TestDecodeByte(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint8
offset int
shouldFail bool
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
result: uint8(0x00),
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x04),
offset: 1,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x4d),
offset: 2,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x51),
offset: 3,
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
result: uint8(0x00),
offset: 8,
shouldFail: true,
},
}
for i, wanted := range expect {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding byte [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding byte [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, i+1, offset, "Incorrect offset value [i:%d]", i)
}
}
func BenchmarkDecodeByte(b *testing.B) {
in := []byte{0, 4, 77, 81, 84, 84}
for n := 0; n < b.N; n++ {
decodeByte(in, 0)
}
}
func TestDecodeUint16(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint16
offset int
shouldFail bool
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x07),
offset: 0,
},
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x761),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
result: uint16(0x761),
offset: 8,
shouldFail: true,
},
}
for i, wanted := range expect {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding uint16 [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding uint16 [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, i+2, offset, "Incorrect offset value [i:%d]", i)
}
}
func BenchmarkDecodeUint16(b *testing.B) {
in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}
for n := 0; n < b.N; n++ {
decodeUint16(in, 0)
}
}
func TestDecodeByteBool(t *testing.T) {
expect := []struct {
rawBytes []byte
result bool
offset int
shouldFail bool
}{
{
rawBytes: []byte{0x00, 0x00},
result: false,
},
{
rawBytes: []byte{0x01, 0x00},
result: true,
},
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: true,
},
}
for i, wanted := range expect {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding byte bool [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding byte bool [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, 1, offset, "Incorrect offset value [i:%d]", i)
}
}
func BenchmarkDecodeByteBool(b *testing.B) {
in := []byte{0x00, 0x00}
for n := 0; n < b.N; n++ {
decodeByteBool(in, 0)
}
}
func TestEncodeBool(t *testing.T) {
result := encodeBool(true)
require.Equal(t, byte(1), result, "Incorrect encoded value; not true")
result = encodeBool(false)
require.Equal(t, byte(0), result, "Incorrect encoded value; not false")
// Check failure.
result = encodeBool(false)
require.NotEqual(t, byte(1), result, "Expected failure, incorrect encoded value")
}
func BenchmarkEncodeBool(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeBool(true)
}
}
func TestEncodeBytes(t *testing.T) {
result := encodeBytes([]byte("testing"))
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result, "Incorrect encoded value")
result = encodeBytes([]byte("testing"))
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result, "Expected failure, incorrect encoded value")
}
func BenchmarkEncodeBytes(b *testing.B) {
bb := []byte("testing")
for n := 0; n < b.N; n++ {
encodeBytes(bb)
}
}
func TestEncodeUint16(t *testing.T) {
result := encodeUint16(0)
require.Equal(t, []byte{0x00, 0x00}, result, "Incorrect encoded value, 0")
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result, "Incorrect encoded value, 32767")
result = encodeUint16(65535)
require.Equal(t, []byte{0xff, 0xff}, result, "Incorrect encoded value, 65535")
}
func BenchmarkEncodeUint16(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeUint16(32767)
}
}
func TestEncodeString(t *testing.T) {
result := encodeString("testing")
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result, "Incorrect encoded value, testing")
result = encodeString("")
require.Equal(t, []uint8{0x00, 0x00}, result, "Incorrect encoded value, null")
result = encodeString("a")
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result, "Incorrect encoded value, a")
result = encodeString("b")
require.NotEqual(t, []uint8{0x00, 0x00}, result, "Expected failure, incorrect encoded value, b")
}
func BenchmarkEncodeString(b *testing.B) {
for n := 0; n < b.N; n++ {
encodeString("benchmarking")
}
}

View File

@@ -1,59 +0,0 @@
package packets
import (
"bytes"
)
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Dup bool // indicates if the packet was already sent at an earlier time.
Qos byte // indicates the quality of service expected.
Retain bool // whether the message should be retained.
Remaining int // the number of remaining bytes in the payload.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, fh.Remaining)
}
// decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(headerByte byte) error {
fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes.
switch fh.Type {
case Publish:
fh.Dup = (headerByte>>3)&0x01 > 0 // Extract flags. Check if message is duplicate.
fh.Qos = (headerByte >> 1) & 0x03 // Extract QoS flag.
fh.Retain = headerByte&0x01 > 0 // Extract retain flag.
case Pubrel:
fh.Qos = (headerByte >> 1) & 0x03
case Subscribe:
fh.Qos = (headerByte >> 1) & 0x03
case Unsubscribe:
fh.Qos = (headerByte >> 1) & 0x03
default:
if (headerByte>>3)&0x01 > 0 || (headerByte>>1)&0x03 > 0 || headerByte&0x01 > 0 {
return ErrInvalidFlags
}
}
return nil
}
// encodeLength writes length bits for the header.
func encodeLength(buf *bytes.Buffer, length int) {
for {
digit := byte(length % 128)
length /= 128
if length > 0 {
digit |= 0x80
}
buf.WriteByte(digit)
if length == 0 {
break
}
}
}

View File

@@ -1,220 +0,0 @@
package packets
import (
"bytes"
"math"
"testing"
"github.com/stretchr/testify/require"
)
type fixedHeaderTable struct {
rawBytes []byte
header FixedHeader
packetError bool
flagError bool
}
var fixedHeaderExpected = []fixedHeaderTable{
{
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int
},
{
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Connack, false, 0, false, 0},
},
{
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Publish, false, 0, false, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Publish, false, 1, false, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 1, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Publish, false, 2, false, 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 2, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Publish, true, 0, false, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Publish, true, 0, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 1, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 2, true, 0},
},
{
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Puback, false, 0, false, 0},
},
{
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Pubrec, false, 0, false, 0},
},
{
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Pubrel, false, 1, false, 0},
},
{
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Pubcomp, false, 0, false, 0},
},
{
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Subscribe, false, 1, false, 0},
},
{
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Suback, false, 0, false, 0},
},
{
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Unsubscribe, false, 1, false, 0},
},
{
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Unsuback, false, 0, false, 0},
},
{
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Pingreq, false, 0, false, 0},
},
{
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Pingresp, false, 0, false, 0},
},
{
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Disconnect, false, 0, false, 0},
},
// remaining length
{
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Publish, false, 0, false, 10},
},
{
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Publish, false, 0, false, 512},
},
{
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Publish, false, 0, false, 978},
},
{
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Publish, false, 0, false, 20102},
},
{
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Publish, false, 0, false, 333333333},
packetError: true,
},
// Invalid flags for packet
{
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Connect, true, 0, false, 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Connect, false, 1, false, 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Connect, false, 0, true, 0},
flagError: true,
},
}
func TestFixedHeaderEncode(t *testing.T) {
for i, wanted := range fixedHeaderExpected {
buf := new(bytes.Buffer)
wanted.header.Encode(buf)
if wanted.flagError == false {
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes)
require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes)
}
}
}
func BenchmarkFixedHeaderEncode(b *testing.B) {
buf := new(bytes.Buffer)
for n := 0; n < b.N; n++ {
fixedHeaderExpected[0].header.Encode(buf)
}
}
func TestFixedHeaderDecode(t *testing.T) {
for i, wanted := range fixedHeaderExpected {
fh := new(FixedHeader)
err := fh.Decode(wanted.rawBytes[0])
if wanted.flagError {
require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Type, fh.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Dup, fh.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Qos, fh.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Retain, fh.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
}
}
}
func BenchmarkFixedHeaderDecode(b *testing.B) {
fh := new(FixedHeader)
for n := 0; n < b.N; n++ {
err := fh.Decode(fixedHeaderExpected[0].rawBytes[0])
if err != nil {
panic(err)
}
}
}
func TestEncodeLength(t *testing.T) {
tt := []struct {
have int
want []byte
}{
{
120,
[]byte{0x78},
},
{
math.MaxInt64,
[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
},
}
for i, wanted := range tt {
buf := new(bytes.Buffer)
encodeLength(buf, wanted.have)
require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %s", i, wanted.have)
}
}
func BenchmarkEncodeLength(b *testing.B) {
buf := new(bytes.Buffer)
for n := 0; n < b.N; n++ {
encodeLength(buf, 120)
}
}

View File

@@ -1,673 +0,0 @@
package packets
import (
"bytes"
"errors"
)
// All of the valid packet types and their packet identifier.
const (
Reserved byte = iota
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Accepted byte = 0x00
Failed byte = 0xFF
CodeConnectBadProtocolVersion byte = 0x01
CodeConnectBadClientID byte = 0x02
CodeConnectServerUnavailable byte = 0x03
CodeConnectBadAuthValues byte = 0x04
CodeConnectNotAuthorised byte = 0x05
CodeConnectNetworkError byte = 0xFE
CodeConnectProtocolViolation byte = 0xFF
ErrSubAckNetworkError byte = 0x80
)
var (
// CONNECT
ErrMalformedProtocolName = errors.New("malformed packet: protocol name")
ErrMalformedProtocolVersion = errors.New("malformed packet: protocol version")
ErrMalformedFlags = errors.New("malformed packet: flags")
ErrMalformedKeepalive = errors.New("malformed packet: keepalive")
ErrMalformedClientID = errors.New("malformed packet: client id")
ErrMalformedWillTopic = errors.New("malformed packet: will topic")
ErrMalformedWillMessage = errors.New("malformed packet: will message")
ErrMalformedUsername = errors.New("malformed packet: username")
ErrMalformedPassword = errors.New("malformed packet: password")
// CONNACK
ErrMalformedSessionPresent = errors.New("malformed packet: session present")
ErrMalformedReturnCode = errors.New("malformed packet: return code")
// PUBLISH
ErrMalformedTopic = errors.New("malformed packet: topic name")
ErrMalformedPacketID = errors.New("malformed packet: packet id")
// SUBSCRIBE
ErrMalformedQoS = errors.New("malformed packet: qos")
// PACKETS
ErrProtocolViolation = errors.New("protocol violation")
ErrOffsetStrOutOfRange = errors.New("offset string out of range")
ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range")
ErrOffsetByteOutOfRange = errors.New("offset byte out of range")
ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")
ErrOffsetUintOutOfRange = errors.New("offset uint out of range")
ErrOffsetStrInvalidUTF8 = errors.New("offset string invalid utf8")
ErrInvalidFlags = errors.New("invalid flags set for packet")
ErrOversizedLengthIndicator = errors.New("protocol violation: oversized length indicator")
ErrMissingPacketID = errors.New("missing packet id")
ErrSurplusPacketID = errors.New("surplus packet id")
)
// Packet is an MQTT packet. Instead of providing a packet interface and variant
// packet structs, this is a single concrete packet type to cover all packet
// types, which allows us to take advantage of various compiler optimizations.
type Packet struct {
FixedHeader FixedHeader
PacketID uint16
// Connect
ProtocolName []byte
ProtocolVersion byte
CleanSession bool
WillFlag bool
WillQos byte
WillRetain bool
UsernameFlag bool
PasswordFlag bool
ReservedBit byte
Keepalive uint16
ClientIdentifier string
WillTopic string
WillMessage []byte
Username []byte
Password []byte
// Connack
SessionPresent bool
ReturnCode byte
// Publish
TopicName string
Payload []byte
// Subscribe, Unsubscribe
Topics []string
Qoss []byte
ReturnCodes []byte // Suback
}
// ConnectEncode encodes a connect packet.
func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
protoName := encodeBytes(pk.ProtocolName)
protoVersion := pk.ProtocolVersion
flag := encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7
keepalive := encodeUint16(pk.Keepalive)
clientID := encodeString(pk.ClientIdentifier)
var willTopic, willFlag, usernameFlag, passwordFlag []byte
// If will flag is set, add topic and message.
if pk.WillFlag {
willTopic = encodeString(pk.WillTopic)
willFlag = encodeBytes(pk.WillMessage)
}
// If username flag is set, add username.
if pk.UsernameFlag {
usernameFlag = encodeBytes(pk.Username)
}
// If password flag is set, add password.
if pk.PasswordFlag {
passwordFlag = encodeBytes(pk.Password)
}
// Get a length for the connect header. This is not super pretty, but it works.
pk.FixedHeader.Remaining =
len(protoName) + 1 + 1 + len(keepalive) + len(clientID) +
len(willTopic) + len(willFlag) +
len(usernameFlag) + len(passwordFlag)
pk.FixedHeader.Encode(buf)
// Eschew magic for readability.
buf.Write(protoName)
buf.WriteByte(protoVersion)
buf.WriteByte(flag)
buf.Write(keepalive)
buf.Write(clientID)
buf.Write(willTopic)
buf.Write(willFlag)
buf.Write(usernameFlag)
buf.Write(passwordFlag)
return nil
}
// ConnectDecode decodes a connect packet.
func (pk *Packet) ConnectDecode(buf []byte) error {
var offset int
var err error
// Unpack protocol name and version.
pk.ProtocolName, offset, err = decodeBytes(buf, 0)
if err != nil {
return ErrMalformedProtocolName
}
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedProtocolVersion
}
// Unpack flags byte.
flags, offset, err := decodeByte(buf, offset)
if err != nil {
return ErrMalformedFlags
}
pk.ReservedBit = 1 & flags
pk.CleanSession = 1&(flags>>1) > 0
pk.WillFlag = 1&(flags>>2) > 0
pk.WillQos = 3 & (flags >> 3) // this one is not a bool
pk.WillRetain = 1&(flags>>5) > 0
pk.PasswordFlag = 1&(flags>>6) > 0
pk.UsernameFlag = 1&(flags>>7) > 0
// Get keepalive interval.
pk.Keepalive, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedKeepalive
}
// Get client ID.
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedClientID
}
// Get Last Will and Testament topic and message if applicable.
if pk.WillFlag {
pk.WillTopic, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedWillTopic
}
pk.WillMessage, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedWillMessage
}
}
// Get username and password if applicable.
if pk.UsernameFlag {
pk.Username, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedUsername
}
}
if pk.PasswordFlag {
pk.Password, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedPassword
}
}
return nil
}
// ConnectValidate ensures the connect packet is compliant.
func (pk *Packet) ConnectValidate() (b byte, err error) {
// End if protocol name is bad.
if bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) != 0 &&
bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) != 0 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if protocol version is bad.
if (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) == 0 && pk.ProtocolVersion != 3) ||
(bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) == 0 && pk.ProtocolVersion != 4) {
return CodeConnectBadProtocolVersion, ErrProtocolViolation
}
// End if reserved bit is not 0.
if pk.ReservedBit != 0 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if ClientID is too long.
if len(pk.ClientIdentifier) > 65535 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if password flag is set without a username.
if pk.PasswordFlag && !pk.UsernameFlag {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if Username or Password is too long.
if len(pk.Username) > 65535 || len(pk.Password) > 65535 {
return CodeConnectProtocolViolation, ErrProtocolViolation
}
// End if client id isn't set and clean session is false.
if !pk.CleanSession && len(pk.ClientIdentifier) == 0 {
return CodeConnectBadClientID, ErrProtocolViolation
}
return Accepted, nil
}
// ConnackEncode encodes a Connack packet.
func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.WriteByte(encodeBool(pk.SessionPresent))
buf.WriteByte(pk.ReturnCode)
return nil
}
// ConnackDecode decodes a Connack packet.
func (pk *Packet) ConnackDecode(buf []byte) error {
var offset int
var err error
pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
if err != nil {
return ErrMalformedSessionPresent
}
pk.ReturnCode, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedReturnCode
}
return nil
}
// DisconnectEncode encodes a Disconnect packet.
func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PingreqEncode encodes a Pingreq packet.
func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PingrespEncode encodes a Pingresp packet.
func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Encode(buf)
return nil
}
// PubackEncode encodes a Puback packet.
func (pk *Packet) PubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubackDecode decodes a Puback packet.
func (pk *Packet) PubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
}
return nil
}
// PubcompEncode encodes a Pubcomp packet.
func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubcompDecode decodes a Pubcomp packet.
func (pk *Packet) PubcompDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
}
return nil
}
// PublishEncode encodes a Publish packet.
func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
topicName := encodeString(pk.TopicName)
var packetID []byte
// Add PacketID if QOS is set.
// [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
if pk.FixedHeader.Qos > 0 {
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID = encodeUint16(pk.PacketID)
}
pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload)
pk.FixedHeader.Encode(buf)
buf.Write(topicName)
buf.Write(packetID)
buf.Write(pk.Payload)
return nil
}
// PublishDecode extracts the data values from the packet.
func (pk *Packet) PublishDecode(buf []byte) error {
var offset int
var err error
pk.TopicName, offset, err = decodeString(buf, 0)
if err != nil {
return ErrMalformedTopic
}
// If QOS decode Packet ID.
if pk.FixedHeader.Qos > 0 {
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedPacketID
}
}
pk.Payload = buf[offset:]
return nil
}
// PublishCopy creates a new instance of Publish packet bearing the
// same payload and destination topic, but with an empty header for
// inheriting new QoS flags, etc.
func (pk *Packet) PublishCopy() Packet {
return Packet{
FixedHeader: FixedHeader{
Type: Publish,
Retain: pk.FixedHeader.Retain,
},
TopicName: pk.TopicName,
Payload: pk.Payload,
}
}
// PublishValidate validates a publish packet.
func (pk *Packet) PublishValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1]
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
// @SPEC [MQTT-2.3.1-5]
// A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0.
if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 {
return Failed, ErrSurplusPacketID
}
return Accepted, nil
}
// PubrecEncode encodes a Pubrec packet.
func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubrecDecode decodes a Pubrec packet.
func (pk *Packet) PubrecDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
}
return nil
}
// PubrelEncode encodes a Pubrel packet.
func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// PubrelDecode decodes a Pubrel packet.
func (pk *Packet) PubrelDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
}
return nil
}
// SubackEncode encodes a Suback packet.
func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
packetID := encodeUint16(pk.PacketID)
pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length.
pk.FixedHeader.Encode(buf)
buf.Write(packetID) // Encode Packet ID.
buf.Write(pk.ReturnCodes) // Encode granted QOS flags.
return nil
}
// SubackDecode decodes a Suback packet.
func (pk *Packet) SubackDecode(buf []byte) error {
var offset int
var err error
// Get Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedPacketID
}
// Get Granted QOS flags.
pk.ReturnCodes = buf[offset:]
return nil
}
// SubscribeEncode encodes a Subscribe packet.
func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
// Add the Packet ID.
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID := encodeUint16(pk.PacketID)
// Count topics lengths and associated QOS flags.
var topicsLen int
for _, topic := range pk.Topics {
topicsLen += len(encodeString(topic)) + 1
}
pk.FixedHeader.Remaining = len(packetID) + topicsLen
pk.FixedHeader.Encode(buf)
buf.Write(packetID)
// Add all provided topic names and associated QOS flags.
for i, topic := range pk.Topics {
buf.Write(encodeString(topic))
buf.WriteByte(pk.Qoss[i])
}
return nil
}
// SubscribeDecode decodes a Subscribe packet.
func (pk *Packet) SubscribeDecode(buf []byte) error {
var offset int
var err error
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
}
// Keep decoding until there's no space left.
for offset < len(buf) {
// Decode Topic Name.
var topic string
topic, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedTopic
}
pk.Topics = append(pk.Topics, topic)
// Decode QOS flag.
var qos byte
qos, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedQoS
}
// Ensure QoS byte is within range.
if !(qos >= 0 && qos <= 2) {
//if !validateQoS(qos) {
return ErrMalformedQoS
}
pk.Qoss = append(pk.Qoss, qos)
}
return nil
}
// SubscribeValidate ensures the packet is compliant.
func (pk *Packet) SubscribeValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1].
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
return Accepted, nil
}
// UnsubackEncode encodes an Unsuback packet.
func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}
// UnsubackDecode decodes an Unsuback packet.
func (pk *Packet) UnsubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
}
return nil
}
// UnsubscribeEncode encodes an Unsubscribe packet.
func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
// Add the Packet ID.
// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.PacketID == 0 {
return ErrMissingPacketID
}
packetID := encodeUint16(pk.PacketID)
// Count topics lengths.
var topicsLen int
for _, topic := range pk.Topics {
topicsLen += len(encodeString(topic))
}
pk.FixedHeader.Remaining = len(packetID) + topicsLen
pk.FixedHeader.Encode(buf)
buf.Write(packetID)
// Add all provided topic names.
for _, topic := range pk.Topics {
buf.Write(encodeString(topic))
}
return nil
}
// UnsubscribeDecode decodes an Unsubscribe packet.
func (pk *Packet) UnsubscribeDecode(buf []byte) error {
var offset int
var err error
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
}
// Keep decoding until there's no space left.
for offset < len(buf) {
var t string
t, offset, err = decodeString(buf, offset) // Decode Topic Name.
if err != nil {
return ErrMalformedTopic
}
if len(t) > 0 {
pk.Topics = append(pk.Topics, t)
}
}
return nil
}
// UnsubscribeValidate validates an Unsubscribe packet.
func (pk *Packet) UnsubscribeValidate() (byte, error) {
// @SPEC [MQTT-2.3.1-1].
// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier.
if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 {
return Failed, ErrMissingPacketID
}
return Accepted, nil
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,343 +0,0 @@
package topics
import (
"strings"
"sync"
"github.com/mochi-co/mqtt/server/internal/packets"
)
// Subscriptions is a map of subscriptions keyed on client.
type Subscriptions map[string]byte
// Index is a prefix/trie tree containing topic subscribers and retained messages.
type Index struct {
mu sync.RWMutex // a mutex for locking the whole index.
Root *Leaf // a leaf containing a message and more leaves.
}
// New returns a pointer to a new instance of Index.
func New() *Index {
return &Index{
Root: &Leaf{
Leaves: make(map[string]*Leaf),
Clients: make(map[string]byte),
},
}
}
// RetainMessage saves a message payload to the end of a topic branch. Returns
// 1 if a retained message was added, 0 if there was no change, and -1 if the
// retained message was removed.
func (x *Index) RetainMessage(msg packets.Packet) int64 {
var q int64
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(msg.TopicName)
if len(msg.Payload) > 0 {
if n.Message.FixedHeader.Retain == false {
q = 1
}
n.Message = msg
} else {
if n.Message.FixedHeader.Retain == true {
q = -1
}
x.unpoperate(msg.TopicName, "", true)
}
return q
}
// Subscribe creates a subscription filter for a client. Returns true if the
// subscription was new.
func (x *Index) Subscribe(filter, client string, qos byte) bool {
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(filter)
_, ok := n.Clients[client]
n.Clients[client] = qos
n.Filter = filter
return !ok
}
// Unsubscribe removes a subscription filter for a client. Returns true if an
// unsubscribe action successful and the subscription existed.
func (x *Index) Unsubscribe(filter, client string) bool {
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(filter)
_, ok := n.Clients[client]
return x.unpoperate(filter, client, false) && ok
}
// unpoperate steps backward through a trie sequence and removes any orphaned
// nodes. If a client id is specified, it will unsubscribe a client. If message
// is true, it will delete a retained message.
func (x *Index) unpoperate(filter string, client string, message bool) bool {
var d int // Walk to end leaf.
var particle string
var hasNext = true
e := x.Root
for hasNext {
particle, hasNext = isolateParticle(filter, d)
d++
e, _ = e.Leaves[particle]
// If the topic part doesn't exist in the tree, there's nothing
// left to do.
if e == nil {
return false
}
}
// Step backward removing client and orphaned leaves.
var key string
var orphaned bool
var end = true
for e.Parent != nil {
key = e.Key
// Wipe the client from this leaf if it's the filter end.
if end {
if client != "" {
delete(e.Clients, client)
}
if message {
e.Message = packets.Packet{}
}
end = false
}
// If this leaf is empty, note it as orphaned.
orphaned = len(e.Clients) == 0 && len(e.Leaves) == 0 && !e.Message.FixedHeader.Retain
// Traverse up the branch.
e = e.Parent
// If the leaf we just came from was empty, delete it.
if orphaned {
delete(e.Leaves, key)
}
}
return true
}
// poperate iterates and populates through a topic/filter path, instantiating
// leaves as it goes and returning the final leaf in the branch.
// poperate is a more enjoyable word than iterpop.
func (x *Index) poperate(topic string) *Leaf {
var d int
var particle string
var hasNext = true
n := x.Root
for hasNext {
particle, hasNext = isolateParticle(topic, d)
d++
child, _ := n.Leaves[particle]
if child == nil {
child = &Leaf{
Key: particle,
Parent: n,
Leaves: make(map[string]*Leaf),
Clients: make(map[string]byte),
}
n.Leaves[particle] = child
}
n = child
}
return n
}
// Subscribers returns a map of clients who are subscribed to matching filters.
func (x *Index) Subscribers(topic string) Subscriptions {
x.mu.RLock()
defer x.mu.RUnlock()
return x.Root.scanSubscribers(topic, 0, make(Subscriptions))
}
// Messages returns a slice of retained topic messages which match a filter.
func (x *Index) Messages(filter string) []packets.Packet {
// ReLeaf("messages", x.Root, 0)
x.mu.RLock()
defer x.mu.RUnlock()
return x.Root.scanMessages(filter, 0, make([]packets.Packet, 0, 32))
}
// Leaf is a child node on the tree.
type Leaf struct {
Key string // the key that was used to create the leaf.
Parent *Leaf // a pointer to the parent node for the leaf.
Leaves map[string]*Leaf // a map of child nodes, keyed on particle id.
Clients map[string]byte // a map of client ids subscribed to the topic.
Filter string // the path of the topic filter being matched.
Message packets.Packet // a message which has been retained for a specific topic.
}
// scanSubscribers recursively steps through a branch of leaves finding clients who
// have subscription filters matching a topic, and their highest QoS byte.
func (l *Leaf) scanSubscribers(topic string, d int, clients Subscriptions) Subscriptions {
part, hasNext := isolateParticle(topic, d)
// For either the topic part, a +, or a #, follow the branch.
for _, particle := range []string{part, "+", "#"} {
// Topics beginning with the reserved $ character are restricted from
// being returned for top level wildcards.
if d == 0 && len(part) > 0 && part[0] == '$' && (particle == "+" || particle == "#") {
continue
}
if child, ok := l.Leaves[particle]; ok {
// We're only interested in getting clients from the final
// element in the topic, or those with wildhashes.
if !hasNext || particle == "#" {
// Capture the highest QOS byte for any client with a filter
// matching the topic.
for client, qos := range child.Clients {
if ex, ok := clients[client]; !ok || ex < qos {
clients[client] = qos
}
}
// Make sure we also capture any client who are listening
// to this topic via path/#
if !hasNext {
if extra, ok := child.Leaves["#"]; ok {
for client, qos := range extra.Clients {
if ex, ok := clients[client]; !ok || ex < qos {
clients[client] = qos
}
}
}
}
}
// If this branch has hit a wildhash, just return immediately.
if particle == "#" {
return clients
} else if hasNext {
clients = child.scanSubscribers(topic, d+1, clients)
}
}
}
return clients
}
// scanMessages recursively steps through a branch of leaves finding retained messages
// that match a topic filter. Setting `d` to -1 will enable wildhash mode, and will
// recursively check ALL child leaves in every subsequent branch.
func (l *Leaf) scanMessages(filter string, d int, messages []packets.Packet) []packets.Packet {
// If a wildhash mode has been set, continue recursively checking through all
// child leaves regardless of their particle key.
if d == -1 {
for _, child := range l.Leaves {
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
messages = child.scanMessages(filter, -1, messages)
}
return messages
}
// Otherwise, we'll get the particle for d in the filter.
particle, hasNext := isolateParticle(filter, d)
// If there's no more particles after this one, then take the messages from
// these topics.
if !hasNext {
// Wildcards and Wildhashes must be checked first, otherwise they
// may be detected as standard particles, and not act properly.
if particle == "+" || particle == "#" {
// Otherwise, if it's a wildcard or wildhash, get messages from all
// the child leaves. This wildhash captures messages on the actual
// wildhash position, whereas the d == -1 block collects subsequent
// messages further down the branch.
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
}
} else if child, ok := l.Leaves[particle]; ok {
if child.Message.FixedHeader.Retain {
messages = append(messages, child.Message)
}
}
} else {
// If it's not the last particle, branch out to the next leaves, scanning
// all available if it's a wildcard, or just one if it's a specific particle.
if particle == "+" {
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
messages = child.scanMessages(filter, d+1, messages)
}
} else if child, ok := l.Leaves[particle]; ok {
messages = child.scanMessages(filter, d+1, messages)
}
}
// If the particle was a wildhash, scan all the child leaves setting the
// d value to wildhash mode.
if particle == "#" {
for _, child := range l.Leaves {
if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' {
continue
}
messages = child.scanMessages(filter, -1, messages)
}
}
return messages
}
// isolateParticle extracts a particle between d / and d+1 / without allocations.
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
var next, end int
for i := 0; end > -1 && i <= d; i++ {
end = strings.IndexRune(filter, '/')
if d > -1 && i == d && end > -1 {
hasNext = true
particle = filter[next:end]
} else if end > -1 {
hasNext = false
filter = filter[end+1:]
} else {
hasNext = false
particle = filter[next:]
}
}
return
}
// ReLeaf is a dev function for showing the trie leafs.
/*
func ReLeaf(m string, leaf *Leaf, d int) {
for k, v := range leaf.Leaves {
fmt.Println(m, d, strings.Repeat(" ", d), k)
ReLeaf(m, v, d+1)
}
}
*/

View File

@@ -1,484 +0,0 @@
package topics
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/server/internal/packets"
)
func TestNew(t *testing.T) {
index := New()
require.NotNil(t, index)
require.NotNil(t, index.Root)
}
func BenchmarkNew(b *testing.B) {
for n := 0; n < b.N; n++ {
New()
}
}
func TestPoperate(t *testing.T) {
index := New()
child := index.poperate("path/to/my/mqtt")
require.Equal(t, "mqtt", child.Key)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
child = index.poperate("a/b/c/d/e")
require.Equal(t, "e", child.Key)
child = index.poperate("a/b/c/c/a")
require.Equal(t, "a", child.Key)
}
func BenchmarkPoperate(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.poperate("path/to/my/mqtt")
}
}
func TestUnpoperate(t *testing.T) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
pk := packets.Packet{TopicName: "path/to/retained/message", Payload: []byte{'h', 'e', 'l', 'l', 'o'}}
index.RetainMessage(pk)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"].Message)
pk2 := packets.Packet{TopicName: "path/to/my/mqtt", Payload: []byte{'s', 'h', 'a', 'r', 'e', 'd'}}
index.RetainMessage(pk2)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
index.unpoperate("path/to/my/mqtt", "", true) // delete retained
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message.FixedHeader.Retain)
index.unpoperate("path/to/my/mqtt", "client-1", false) // unsubscribe client
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
index.unpoperate("path/to/retained/message", "", true) // delete retained
require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves, "my")
index.unpoperate("path/to/whatever", "client-1", false) // unsubscribe client
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
//require.Empty(t, index.Root.Leaves["path"])
}
func BenchmarkUnpoperate(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.poperate("path/to/my/mqtt")
}
}
func TestRetainMessage(t *testing.T) {
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
TopicName: "path/to/my/mqtt",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
}
pk2 := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
TopicName: "path/to/another/mqtt",
Payload: []byte{'h', 'e', 'l', 'l', 'o'},
}
index := New()
q := index.RetainMessage(pk)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients["client-1"])
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
q = index.RetainMessage(pk2)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
q = index.RetainMessage(pk2) // already exsiting
require.Equal(t, int64(0), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
// Delete retained
pk3 := packets.Packet{TopicName: "path/to/another/mqtt", Payload: []byte{}}
q = index.RetainMessage(pk3)
require.Equal(t, int64(-1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
}
func BenchmarkRetainMessage(b *testing.B) {
index := New()
pk := packets.Packet{TopicName: "path/to/another/mqtt"}
for n := 0; n < b.N; n++ {
index.RetainMessage(pk)
}
}
func TestSubscribeOK(t *testing.T) {
index := New()
q := index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/to/my/mqtt", "client-1", 0)
require.Equal(t, false, q)
q = index.Subscribe("path/to/my/mqtt", "client-2", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/to/another/mqtt", "client-1", 0)
require.Equal(t, true, q)
q = index.Subscribe("path/+", "client-2", 0)
require.Equal(t, true, q)
q = index.Subscribe("#", "client-3", 0)
require.Equal(t, true, q)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Equal(t, "path/to/my/mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Filter)
require.Equal(t, "mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Key)
require.Equal(t, index.Root.Leaves["path"], index.Root.Leaves["path"].Leaves["to"].Parent)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["+"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
}
func BenchmarkSubscribe(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.Subscribe("path/to/mqtt/basic", "client-1", 0)
}
}
func TestUnsubscribeA(t *testing.T) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Subscribe("path/to/+/mqtt", "client-1", 0)
index.Subscribe("path/to/stuff", "client-1", 0)
index.Subscribe("path/to/stuff", "client-2", 0)
index.Subscribe("#", "client-3", 0)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
ok := index.Unsubscribe("path/to/my/mqtt", "client-1")
require.Equal(t, true, ok)
require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"])
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1")
ok = index.Unsubscribe("path/to/stuff", "client-1")
require.Equal(t, true, ok)
require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1")
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2")
require.Contains(t, index.Root.Leaves["#"].Clients, "client-3")
ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "client-1")
require.Equal(t, false, ok)
}
func TestUnsubscribeCascade(t *testing.T) {
index := New()
index.Subscribe("a/b/c", "client-1", 0)
index.Subscribe("a/b/c/e/e", "client-1", 0)
ok := index.Unsubscribe("a/b/c/e/e", "client-1")
require.Equal(t, true, ok)
require.NotEmpty(t, index.Root.Leaves)
require.Contains(t, index.Root.Leaves["a"].Leaves["b"].Leaves["c"].Clients, "client-1")
}
// This benchmark is Unsubscribe-Subscribe
func BenchmarkUnsubscribe(b *testing.B) {
index := New()
for n := 0; n < b.N; n++ {
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Unsubscribe("path/to/mqtt/basic", "client-1")
}
}
func TestSubscribersFind(t *testing.T) {
tt := []struct {
filter string
topic string
len int
}{
{
filter: "a",
topic: "a",
len: 1,
},
{
filter: "a/",
topic: "a",
len: 0,
},
{
filter: "a/",
topic: "a/",
len: 1,
},
{
filter: "/a",
topic: "/a",
len: 1,
},
{
filter: "path/to/my/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "path/to/+/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/to/+/mqtt",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/+/+/+",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "+/+/+/#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "zen/#",
topic: "zen",
len: 1,
},
{
filter: "+/+/#",
topic: "path/to/my/mqtt",
len: 1,
},
{
filter: "path/to/",
topic: "path/to/my/mqtt",
len: 0,
},
{
filter: "#/stuff",
topic: "path/to/my/mqtt",
len: 0,
},
{
filter: "$SYS/#",
topic: "$SYS/info",
len: 1,
},
{
filter: "#",
topic: "$SYS/info",
len: 0,
},
{
filter: "+/info",
topic: "$SYS/info",
len: 0,
},
}
for i, check := range tt {
index := New()
index.Subscribe(check.filter, "client-1", 0)
clients := index.Subscribers(check.topic)
//spew.Dump(clients)
require.Equal(t, check.len, len(clients), "Unexpected clients len at %d %s %s", i, check.filter, check.topic)
}
}
func BenchmarkSubscribers(b *testing.B) {
index := New()
index.Subscribe("path/to/my/mqtt", "client-1", 0)
index.Subscribe("path/to/+/mqtt", "client-1", 0)
index.Subscribe("something/things/stuff/+", "client-1", 0)
index.Subscribe("path/to/stuff", "client-2", 0)
index.Subscribe("#", "client-3", 0)
for n := 0; n < b.N; n++ {
index.Subscribers("path/to/testing/mqtt")
}
}
func TestIsolateParticle(t *testing.T) {
particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
require.Equal(t, "to", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
require.Equal(t, "my", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
require.Equal(t, "mqtt", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("/path/", 0)
require.Equal(t, "", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 1)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 2)
require.Equal(t, "", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 3)
require.Equal(t, "+", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 4)
require.Equal(t, "+", particle)
require.Equal(t, false, hasNext)
}
func BenchmarkIsolateParticle(b *testing.B) {
for n := 0; n < b.N; n++ {
isolateParticle("path/to/my/mqtt", 3)
}
}
func TestMessagesPattern(t *testing.T) {
tt := []struct {
packet packets.Packet
filter string
len int
}{
{
packets.Packet{TopicName: "a/b/c/d", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"a/b/c/d",
1,
},
{
packets.Packet{TopicName: "a/b/c/e", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"a/+/c/+",
2,
},
{
packets.Packet{TopicName: "a/b/d/f", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"+/+/+/+",
3,
},
{
packets.Packet{TopicName: "q/w/e/r/t/y", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"q/w/e/#",
1,
},
{
packets.Packet{TopicName: "q/w/x/r/t/x", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"q/#",
2,
},
{
packets.Packet{TopicName: "asd", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"asd",
1,
},
{
packets.Packet{TopicName: "$SYS/testing", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"#",
8,
},
{
packets.Packet{TopicName: "$SYS/test", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"+/testing",
0,
},
{
packets.Packet{TopicName: "$SYS/info", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"$SYS/info",
1,
},
{
packets.Packet{TopicName: "$SYS/b", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"$SYS/#",
4,
},
{
packets.Packet{TopicName: "asd/fgh/jkl", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"#",
8,
},
{
packets.Packet{TopicName: "stuff/asdadsa/dsfdsafdsadfsa/dsfdsf/sdsadas", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}},
"stuff/#/things", // indexer will ignore trailing /things
1,
},
}
index := New()
for _, check := range tt {
index.RetainMessage(check.packet)
}
for i, check := range tt {
messages := index.Messages(check.filter)
require.Equal(t, check.len, len(messages), "Unexpected messages len at %d %s %s", i, check.filter, check.packet.TopicName)
}
}
func TestMessagesFind(t *testing.T) {
index := New()
index.RetainMessage(packets.Packet{TopicName: "a/a", Payload: []byte{'a'}, FixedHeader: packets.FixedHeader{Retain: true}})
index.RetainMessage(packets.Packet{TopicName: "a/b", Payload: []byte{'b'}, FixedHeader: packets.FixedHeader{Retain: true}})
messages := index.Messages("a/a")
require.Equal(t, 1, len(messages))
messages = index.Messages("a/+")
require.Equal(t, 2, len(messages))
}
func BenchmarkMessages(b *testing.B) {
index := New()
index.RetainMessage(packets.Packet{TopicName: "path/to/my/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "path/to/another/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "path/a/some/mqtt"})
index.RetainMessage(packets.Packet{TopicName: "what/is"})
index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
for n := 0; n < b.N; n++ {
index.Messages("path/to/+/mqtt")
}
}

View File

@@ -1,12 +0,0 @@
package auth
// Controller is an interface for authentication controllers.
type Controller interface {
// Authenticate authenticates a user on CONNECT and returns true if a user is
// allowed to join the server.
Authenticate(user, password []byte) bool
// ACL returns true if a user has read or write access to a given topic.
ACL(user []byte, topic string, write bool) bool
}

View File

@@ -1,31 +0,0 @@
package auth
// Allow is an auth controller which allows access to all connections and topics.
type Allow struct{}
// Auth returns true if a username and password are acceptable. Allow always
// returns true.
func (a *Allow) Authenticate(user, password []byte) bool {
return true
}
// ACL returns true if a user has access permissions to read or write on a topic.
// Allow always returns true.
func (a *Allow) ACL(user []byte, topic string, write bool) bool {
return true
}
// Disallow is an auth controller which disallows access to all connections and topics.
type Disallow struct{}
// Auth returns true if a username and password are acceptable. Disallow always
// returns false.
func (d *Disallow) Authenticate(user, password []byte) bool {
return false
}
// ACL returns true if a user has access permissions to read or write on a topic.
// Disallow always returns false.
func (d *Disallow) ACL(user []byte, topic string, write bool) bool {
return false
}

View File

@@ -1,55 +0,0 @@
package auth
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAllowAuth(t *testing.T) {
ac := new(Allow)
require.Equal(t, true, ac.Authenticate([]byte("user"), []byte("pass")))
}
func BenchmarkAllowAuth(b *testing.B) {
ac := new(Allow)
for n := 0; n < b.N; n++ {
ac.Authenticate([]byte("user"), []byte("pass"))
}
}
func TestAllowACL(t *testing.T) {
ac := new(Allow)
require.Equal(t, true, ac.ACL([]byte("user"), "topic", true))
}
func BenchmarkAllowACL(b *testing.B) {
ac := new(Allow)
for n := 0; n < b.N; n++ {
ac.ACL([]byte("user"), "pass", true)
}
}
func TestDisallowAuth(t *testing.T) {
ac := new(Disallow)
require.Equal(t, false, ac.Authenticate([]byte("user"), []byte("pass")))
}
func BenchmarkDisallowAuth(b *testing.B) {
ac := new(Disallow)
for n := 0; n < b.N; n++ {
ac.Authenticate([]byte("user"), []byte("pass"))
}
}
func TestDisallowACL(t *testing.T) {
ac := new(Disallow)
require.Equal(t, false, ac.ACL([]byte("user"), "topic", true))
}
func BenchmarkDisallowACL(b *testing.B) {
ac := new(Disallow)
for n := 0; n < b.N; n++ {
ac.ACL([]byte("user"), "pass", true)
}
}

View File

@@ -1,121 +0,0 @@
package listeners
import (
"context"
"crypto/tls"
"encoding/json"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener.
config *Config // configuration values for the listener.
system *system.Info // pointers to the server data.
address string // the network address to bind to.
listen *http.Server // the http server.
end int64 // ensure the close methods are only called once.}
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
func NewHTTPStats(id, address string) *HTTPStats {
return &HTTPStats{
id: id,
address: address,
config: &Config{
Auth: new(auth.Allow),
},
}
}
// SetConfig sets the configuration values for the listener config.
func (l *HTTPStats) SetConfig(config *Config) {
l.Lock()
if config != nil {
l.config = config
// If a config has been passed without an auth controller,
// it may be a mistake, so disallow all traffic.
if l.config.Auth == nil {
l.config.Auth = new(auth.Disallow)
}
}
l.Unlock()
}
// ID returns the id of the listener.
func (l *HTTPStats) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Listen starts listening on the listener's network address.
func (l *HTTPStats) Listen(s *system.Info) error {
l.system = s
mux := http.NewServeMux()
mux.HandleFunc("/", l.jsonHandler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
}
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
return nil
}
// Serve starts listening for new connections and serving responses.
func (l *HTTPStats) Serve(establish EstablishFunc) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *HTTPStats) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info, err := json.MarshalIndent(l.system, "", "\t")
if err != nil {
io.WriteString(w, err.Error())
return
}
w.Write(info)
}

View File

@@ -1,186 +0,0 @@
package listeners
import (
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
"github.com/stretchr/testify/require"
)
func TestNewHTTPStats(t *testing.T) {
l := NewHTTPStats("t1", testPort)
require.Equal(t, "t1", l.id)
require.Equal(t, testPort, l.address)
}
func BenchmarkNewHTTPStats(b *testing.B) {
for n := 0; n < b.N; n++ {
NewHTTPStats("t1", testPort)
}
}
func TestHTTPStatsSetConfig(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
})
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Allow), l.config.Auth)
// Switch to disallow on bad config set.
l.SetConfig(new(Config))
require.NotNil(t, l.config)
require.NotNil(t, l.config.Auth)
require.Equal(t, new(auth.Disallow), l.config.Auth)
}
func BenchmarkHTTPStatsSetConfig(b *testing.B) {
l := NewHTTPStats("t1", testPort)
for n := 0; n < b.N; n++ {
l.SetConfig(new(Config))
}
}
func TestHTTPStatsID(t *testing.T) {
l := NewHTTPStats("t1", testPort)
require.Equal(t, "t1", l.ID())
}
func BenchmarkHTTPStatsID(b *testing.B) {
l := NewHTTPStats("t1", testPort)
for n := 0; n < b.N; n++ {
l.ID()
}
}
func TestHTTPStatsListen(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.system)
require.NotNil(t, l.listen)
require.Equal(t, testPort, l.listen.Addr)
l.listen.Close()
}
func TestHTTPStatsListenTLS(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(new(system.Info))
require.NoError(t, err)
require.NotNil(t, l.listen.TLSConfig)
l.listen.Close()
}
func TestHTTPStatsListenTLSInvalid(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: []byte("abcde"),
PrivateKey: testPrivateKey,
},
})
err := l.Listen(new(system.Info))
require.Error(t, err)
}
func TestHTTPStatsServeAndClose(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
resp, err := http.Get("http://localhost" + testPort)
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Get("http://localhost" + testPort)
require.Error(t, err)
<-o
}
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
l := NewHTTPStats("t1", testPort)
l.SetConfig(&Config{
Auth: new(auth.Allow),
TLS: &TLS{
Certificate: testCertificate,
PrivateKey: testPrivateKey,
},
})
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestHTTPStatsJSONHandler(t *testing.T) {
l := NewHTTPStats("t1", testPort)
err := l.Listen(&system.Info{
Version: "test",
})
require.NoError(t, err)
w := httptest.NewRecorder()
l.jsonHandler(w, nil)
resp := w.Result()
body, _ := ioutil.ReadAll(resp.Body)
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
}

View File

@@ -1,229 +0,0 @@
package listeners
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
var (
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
)
func TestNew(t *testing.T) {
l := New(nil)
require.NotNil(t, l.internal)
}
func BenchmarkNewListeners(b *testing.B) {
for n := 0; n < b.N; n++ {
New(nil)
}
}
func TestAddListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
require.Contains(t, l.internal, "t1")
}
func BenchmarkAddListener(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
for n := 0; n < b.N; n++ {
l.Add(mocked)
}
}
func TestGetListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
g, ok := l.Get("t1")
require.Equal(t, true, ok)
require.Equal(t, g.ID(), "t1")
}
func BenchmarkGetListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Get("t1")
}
}
func TestLenListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
require.Equal(t, 2, l.Len())
}
func BenchmarkLenListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Len()
}
}
func TestDeleteListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
require.Contains(t, l.internal, "t1")
l.Delete("t1")
_, ok := l.Get("t1")
require.Equal(t, false, ok)
require.Nil(t, l.internal["t1"])
}
func BenchmarkDeleteListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Delete("t1")
}
}
func TestServeListener(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing)
l.Close("t1", MockCloser)
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing)
}
func BenchmarkServeListener(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
for n := 0; n < b.N; n++ {
l.Serve("t1", MockEstablisher)
}
}
func TestServeAllListeners(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
l.Add(NewMockListener("t3", ":1882"))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing)
require.Equal(t, true, l.internal["t2"].(*MockListener).IsServing)
require.Equal(t, true, l.internal["t3"].(*MockListener).IsServing)
l.Close("t1", MockCloser)
l.Close("t2", MockCloser)
l.Close("t3", MockCloser)
require.Equal(t, false, l.internal["t1"].(*MockListener).IsServing)
require.Equal(t, false, l.internal["t2"].(*MockListener).IsServing)
require.Equal(t, false, l.internal["t3"].(*MockListener).IsServing)
}
func BenchmarkServeAllListeners(b *testing.B) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1883"))
l.Add(NewMockListener("t3", ":1884"))
for n := 0; n < b.N; n++ {
l.ServeAll(MockEstablisher)
}
}
func TestCloseListener(t *testing.T) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
var closed bool
l.Close("t1", func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func BenchmarkCloseListener(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
for n := 0; n < b.N; n++ {
l.internal["t1"].(*MockListener).done = make(chan bool)
l.Close("t1", MockCloser)
}
}
func TestCloseAllListeners(t *testing.T) {
l := New(nil)
l.Add(NewMockListener("t1", ":1882"))
l.Add(NewMockListener("t2", ":1882"))
l.Add(NewMockListener("t3", ":1882"))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.Equal(t, true, l.internal["t1"].(*MockListener).IsServing)
require.Equal(t, true, l.internal["t2"].(*MockListener).IsServing)
require.Equal(t, true, l.internal["t3"].(*MockListener).IsServing)
closed := make(map[string]bool)
l.CloseAll(func(id string) {
closed[id] = true
})
require.Contains(t, closed, "t1")
require.Contains(t, closed, "t2")
require.Contains(t, closed, "t3")
require.Equal(t, true, closed["t1"])
require.Equal(t, true, closed["t2"])
require.Equal(t, true, closed["t3"])
}
func BenchmarkCloseAllListeners(b *testing.B) {
l := New(nil)
mocked := NewMockListener("t1", ":1882")
l.Add(mocked)
l.Serve("t1", MockEstablisher)
for n := 0; n < b.N; n++ {
l.internal["t1"].(*MockListener).done = make(chan bool)
l.Close("t1", MockCloser)
}
}

View File

@@ -1,89 +0,0 @@
package listeners
import (
"fmt"
"net"
"sync"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/system"
)
// MockCloser is a function signature which can be used in testing.
func MockCloser(id string) {}
// MockEstablisher is a function signature which can be used in testing.
func MockEstablisher(id string, c net.Conn, ac auth.Controller) error {
return nil
}
// MockListener is a mock listener for establishing client connections.
type MockListener struct {
sync.RWMutex
id string // the id of the listener.
Config *Config // configuration for the listener.
address string // the network address the listener binds to.
IsListening bool // indiciate the listener is listening.
IsServing bool // indicate the listener is serving.
done chan bool // indicate the listener is done.
ErrListen bool // throw an error on listen.
}
// NewMockListener returns a new instance of MockListener
func NewMockListener(id, address string) *MockListener {
return &MockListener{
id: id,
address: address,
done: make(chan bool),
}
}
// Serve serves the mock listener.
func (l *MockListener) Serve(establisher EstablishFunc) {
l.Lock()
l.IsServing = true
l.Unlock()
for {
select {
case <-l.done:
return
}
}
}
// SetConfig sets the configuration values of the mock listener.
func (l *MockListener) Listen(s *system.Info) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
}
l.Lock()
l.IsListening = true
l.Unlock()
return nil
}
// SetConfig sets the configuration values of the mock listener.
func (l *MockListener) SetConfig(config *Config) {
l.Lock()
l.Config = config
l.Unlock()
}
// ID returns the id of the mock listener.
func (l *MockListener) ID() string {
l.RLock()
id := l.id
l.RUnlock()
return id
}
// Close closes the mock listener.
func (l *MockListener) Close(closer CloseFunc) {
l.Lock()
defer l.Unlock()
l.IsServing = false
closer(l.id)
close(l.done)
}

View File

@@ -1,79 +0,0 @@
package listeners
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/server/listeners/auth"
)
func TestMockEstablisher(t *testing.T) {
_, w := net.Pipe()
err := MockEstablisher("t1", w, new(auth.Allow))
require.NoError(t, err)
w.Close()
}
func TestNewMockListener(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, "t1", mocked.id)
require.Equal(t, ":1882", mocked.address)
}
func TestNewMockListenerListen(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, "t1", mocked.id)
require.Equal(t, ":1882", mocked.address)
require.Equal(t, false, mocked.IsListening)
err := mocked.Listen(nil)
require.NoError(t, err)
require.Equal(t, true, mocked.IsListening)
}
func TestNewMockListenerListenFailure(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
mocked.ErrListen = true
err := mocked.Listen(nil)
require.Error(t, err)
}
func TestMockListenerServe(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
require.Equal(t, false, mocked.IsServing)
o := make(chan bool)
go func(o chan bool) {
mocked.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
require.Equal(t, true, mocked.IsServing)
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
mocked.Listen(nil)
}
func TestMockListenerSetConfig(t *testing.T) {
mocked := NewMockListener("t1", ":1883")
mocked.SetConfig(new(Config))
require.NotNil(t, mocked.Config)
}
func TestMockListenerClose(t *testing.T) {
mocked := NewMockListener("t1", ":1882")
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}

Some files were not shown because too many files have changed in this diff Show More