Compare commits

...

148 Commits

Author SHA1 Message Date
mochi-co
fac733fd71 Update readme v2.4.0 2023-09-08 23:25:02 +01:00
JB
add87fea2e Small fixes and cleanups (#295)
* fix typos, indicate unused returns

* Add test for publishToClient acl unauthorized

* Add Inline Client as a server option
2023-09-08 23:06:14 +01:00
thedevop
58f9fed336 Disconnect or return ack if unauthorized publish (#292)
* Ensure msg doesn't exceed subscription QoS

* Disconnect or return ack if unauthorized publish

* Disconnect or return ack if unauthorized publish

* Create new server for eery test case

---------

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-09-08 20:51:35 +01:00
werbenhu
1574443981 Another code implementation for Inline Client Subscriptions. (#284)
* Another code implementation for Inline Client Subscriptions.

* Added a few test cases.

* Changed the return value types of Server.Unsubscribe() and Subscribe() to boolean.

* Implementing the delivery of retained messages and supporting multiple callbacks per topic using different inline client IDs.

* Added validation checks for the legality of the inline client id during Subscribe and Unsubscribe.

* Added validation checks for the legality of the client during Subscribe and Unsubscribe.

* Fixed the TestServerSubscribe/invalid_client_id test case failure.

* Add Server.inlineClient and Temporarily removing test cases for better code review readability.

* Using server.inlineClient in server.InjectPacket().

* After unsubscribing, if there are other subscriptions in particle.inlineSubscriptions, particle cannot be deleted.

* Add comments to particle.inlineSubscriptions and modify to return ErrTopicFilterInvalid when the topic is invalid during subscription.

* Fixed some test case failures caused by adding inlineClient to the server.

* More test cases have been added.

* Optimization of test case code.

* Modify server.go: When used as a publisher, treat the qos of inline client-published messages as 0.

* Resolve conflict.
2023-09-08 20:45:08 +01:00
Derek Duncan
44bac0adc5 Migrate from zerolog to slog (#248)
* Begin adding new slog calls

* Fixed unit tests

* Add leveler example

* Add debug log level to Redis example

* Change location of server.Close() and add logs to example/hooks

* Begin removing references to zerolog

* Removed final references to zerolog

* Change where server.Close() occurs in main

* Change to 1.21 to remove x dependency

* Add slog

* Update references to 1.21

* Begin change of LogAttrs to standard logging interface

* Change the rest of LogAttrs to default

* Fix bad log

* Update badger.go

Changing "data" to "key" or "id" here might be more appropriate.

* Update badger.go

Changing "data" to "key" or "id" here might be more appropriate.

* Update server.go

Not checking if err is equal to nil

* Update server.go

printing information for ID or error is missing.

* Change references of err.Error() to err in slog

* Remove missed removal of Error() references for logging

---------

Co-authored-by: Derek Duncan <dduncan@atlassian.com>
Co-authored-by: Derek Duncan <derekduncan@gmail.com>
Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
Co-authored-by: werbenhu <werben@qq.com>
2023-09-06 15:21:04 +01:00
xyzj
e784c755ae add aclcheck in publishToClient (#287)
Co-authored-by: unknow <beunknow@outlook.com>
2023-09-06 15:15:08 +01:00
Derek Duncan
eafc2d91fc Fix readme example (#276)
* Update README example to better match other examples

* Fix formatting

* Fix err formatting issue

* Fix bad import in README example

---------

Co-authored-by: Derek Duncan <dduncan@atlassian.com>
2023-08-12 10:00:52 +01:00
Wind
0df69a4a4e Use JSONeq to compare JSON (#267)
* WriterSize parameter is incorrectly set

The WriterSize parameter is incorrectly set in the newClient method.

* Use JSONeq to compare JSON

This ensures that test results pass even if the field order is inconsistent.

---------

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-08-01 10:25:17 +01:00
JB
321a0514fe Update README.md 2023-07-31 21:02:14 +01:00
JB
00387593c9 Fix badges 2023-07-31 20:56:51 +01:00
JB
af78d10870 Update README.md 2023-07-31 20:55:06 +01:00
JB
30ca94e878 Update README.md 2023-07-31 13:27:06 +01:00
JB
ae3f72f677 migrate imports, copyrights, etc (#270) 2023-07-31 13:26:26 +01:00
mochi-co
9838262e66 Update server version 2023-07-20 23:08:10 +01:00
JB
ac812154e6 Allow Publish to return custom Ack error responses (#256)
* Allow publish error returns as acks

* Add Ignore Packet, tests
2023-07-20 22:52:16 +01:00
Gabriel Sagula
0234589152 fix: fix data-race in badger hook (#266)
Co-authored-by: Gabriel Sagula <gsagula@magicleap.com>
Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-07-20 22:49:34 +01:00
KE
ea97038052 method UnsubscribeClient's packet add fixedHeader (#264) 2023-07-20 22:41:49 +01:00
JB
48233334a5 Do not retain messages if retain is not available (#261)
* Do not retain messages if retain is not available

* Add Test
2023-07-16 20:41:28 +01:00
JB
33429451c8 Preference Write, Read, Deny filters in ledger (#262) 2023-07-16 20:39:43 +01:00
JB
050e24662f Retain flag should be delivered as false in v3 (#257)
* retain should be delivered as false in v3

* Forward retained flag if publish is from subscribe action
2023-07-16 20:36:45 +01:00
Ian Rose
aec29e350e Fix websocket reads for packets > 1 buffer size (#260) 2023-07-16 20:36:15 +01:00
JB
cb99d6f4bc Update README.md 2023-07-13 21:36:21 +01:00
thedevop
c77d1c0331 Ensure msg doesn't exceed subscription QoS (#253)
Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-07-13 18:19:22 +01:00
Wind
6f42c3fd65 WriterSize parameter is incorrectly set (#252)
The WriterSize parameter is incorrectly set in the newClient method.
2023-07-13 18:12:11 +01:00
mochi-co
9c52292732 Small language clarification for non-english 2023-07-08 13:48:59 +01:00
mochi-co
990f308faa Update server version 2023-07-08 13:47:55 +01:00
Derek Duncan
fe0c1d15a6 Add OnSessionEstablish hook (#247)
Co-authored-by: Derek Duncan <derekduncan@gmail.com>
2023-07-08 13:09:58 +01:00
mochi-co
0648e39507 Update Readme 2023-06-19 10:28:22 +01:00
mochi-co
233a82e448 Add Healthcheck listener 2023-06-19 10:22:55 +01:00
JB
51a8d8cb54 Update README.md 2023-06-19 10:15:37 +01:00
mochi-co
23c3208310 Update SPDX annotations 2023-06-19 10:14:07 +01:00
mochi-co
23e1092cda Update Contribution Guidelines 2023-06-19 10:13:43 +01:00
mochi-co
d498576927 Update server version 2023-06-19 09:51:41 +01:00
Derek Duncan
7e14ce99b5 Add healthcheck listener (#244)
* Add healthcheck listener

* Update improper comments

---------

Co-authored-by: Derek Duncan <derekduncan@gmail.com>
Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-06-19 09:44:03 +01:00
thedevop
4db49a4b9d Fix ScanSubscribersTopicInheritanceBug (#243)
* Sub a/b should not receive msg for a/b/c...

* Add TestScanSubscribersTopicInheritanceBug test

* Ensure sharedSubscription are gathered

* Fix Unsubscribe for sharedSub and optimization

* Unsub with lower case in TestUnsubscribeShared

* Add test with # for TestScanSubscribersShared
2023-06-19 09:42:16 +01:00
mochi-co
e60b8ff0c9 Update Hooks List 2023-06-14 20:12:52 +01:00
mochi-co
b9d5dcb5f0 Update Server Version 2023-06-14 20:12:31 +01:00
JB
6d394d1fe9 Expose SendConnack, err return on OnConnect (#240) 2023-06-14 20:05:02 +01:00
thedevop
1ee2158711 Add OnRetainPublished hook (#237)
* Add OnRetainPublished hook

* Skip OnRetainPublished if publish error
2023-06-13 19:24:04 +01:00
mochi-co
af79b55b9f Update server version 2023-06-04 07:32:34 +01:00
Derek Duncan
e1a9497c25 Add retainMessage to LWT to properly handle message retention (#234)
* Add retainMessage to LWT to properly handle message retention if specified in connect

* Add will retain flag on missed test

---------

Co-authored-by: Derek Duncan <derekduncan@gmail.com>
2023-06-04 07:31:55 +01:00
mochi-co
62659e17ba Update server version 2023-05-18 20:29:56 +01:00
Hector Oliveros
7ad6dd8e1a Now when a "publish" command fails, then the publish method will throw an error (#229)
Errors in the hook when doing a publish were ignored. This caused that test cases could not be made where the publish failed and an error was thrown.

Co-authored-by: hector.oliveros@wabtec.com <hectoroliveros@MacBook-Pro-de-Hector.local>
2023-05-18 20:14:50 +01:00
thedevop
565e07747e Minimize client lock duration (#223)
* Minimize client lock duration
* Fix server option example
2023-05-18 20:01:35 +01:00
plourdedominic
6acd775a6b Fix example usage of NewHTTPStats (#231)
Co-authored-by: Dominic Plourde <plourded@amotus.ca>
2023-05-18 16:14:29 +01:00
JB
493f6c8bb0 Update README.md new benchmarks 2023-05-15 20:01:12 +01:00
mochi-co
d3785c2717 update server version 2023-05-08 11:43:46 +01:00
thedevop
52a347169a Use context to exit WriteLoop (#222)
* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Fix misspelling
2023-05-08 11:30:44 +01:00
mochi-co
797d75cb34 update server version 2023-05-06 14:32:42 +01:00
JB
5225a357e5 refactor server keepalive for hook access (#220) 2023-05-06 14:11:54 +01:00
JB
a734a0dc73 Use context to signal client open state (#218) 2023-05-06 11:55:40 +01:00
JB
6704cf7227 Add packet ID exhausted hook (#217) 2023-05-06 10:37:27 +01:00
thedevop
9233e6fd39 Expire session if SessionExpiryInterval is 0 (#216)
If SessionExpiryInterval was not set in CONNECT, SessionExpiryIntervalFlag is also not set. According to spec:
  If the Session Expiry Interval is absent the value 0 is used. If it is set to 0, or is absent, the Session ends when the Network Connection is closed.
2023-05-06 10:12:33 +01:00
ħþ
1ca65d9631 Update codes.go (#215)
fix typo
2023-05-06 10:02:25 +01:00
ħþ
33229da885 Update codes.go (#214)
Fix typo
2023-05-06 09:59:50 +01:00
mochi-co
c274d5fd08 Update server version 2023-05-05 00:04:27 +01:00
JB
10e82f41d6 Lock on close outbound (#213) 2023-05-05 00:02:49 +01:00
JB
e6c07b2b78 Add lock to client writes (#212) 2023-05-04 23:17:12 +01:00
JB
eed3ef9606 Add OnPacketIDExhausted hook (#211) 2023-05-04 22:51:40 +01:00
JB
1ec880844d Correctly validate WillProperties (#210)
Co-authored-by: sukvojte <sukvojte@gmail.com>
2023-05-04 22:37:23 +01:00
werbenhu
4b49652a8c Update build.yml (#203) 2023-05-04 18:09:58 +01:00
mochi-co
d46e7b5bcf Protect close of nil outbound channel 2023-04-21 22:09:59 +01:00
mochi-co
17fb7dadbc Protect close of nil outbound channel 2023-04-21 22:00:27 +01:00
werben
ed7fd836e1 #78 storage hook should not execute the relevant code if the client has been reconnected (#198)
* storage hook should not execute the relevant code if the client has been reconnected #78

* add test cases for coverage decrease

add test cases for coverage decrease
2023-04-21 21:52:44 +01:00
mochi-co
605bb93c75 Move msgToPacket to storage.Message.ToPacket 2023-04-21 21:49:49 +01:00
Wind
c73ace2ea0 Simplified code (#195)
Simplify the code for the loadInflight and loadRetained methods.
Adjust the validation order of the processSubscribe method to ensure that it fails quickly if there is an error, since s.hooks. OnACLCheck generally takes a long time.

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-04-21 21:29:03 +01:00
thedevop
aac6d699da Ensure to close client WriteLoop (#193)
* Ensure client WriteLoop is closed

* Ensure to close client WriteLoop
2023-04-21 21:20:46 +01:00
Hubertus Hohl
7bd7bd5087 fix: common subscriptions issued by different clients at the same time may be lost (#186) 2023-03-11 23:17:10 +00:00
mochi-co
655bf9fdb1 Update readme 2023-03-11 23:15:51 +00:00
mochi-co
b188055c7d Update server version 2023-03-11 23:14:51 +00:00
JB
aaf1d9d4c6 Configurable client bufio reader/writer sizes (#190) 2023-03-11 23:13:28 +00:00
mochi-co
44ce819318 Update server version 2023-02-28 20:58:05 +00:00
dependabot[bot]
e4c76cc60c Bump golang.org/x/net from 0.0.0-20220927171203-f486391704dc to 0.7.0 (#182)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.0.0-20220927171203-f486391704dc to 0.7.0.
- [Release notes](https://github.com/golang/net/releases)
- [Commits](https://github.com/golang/net/commits/v0.7.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-02-28 20:57:33 +00:00
thedevop
da79faa972 Skip expire cleanup for isTakenOver session (#183)
* Skip expire cleanup for isTakenOver session

* Set prev connection to isTakenOver on CleanSession

#173.
2023-02-28 20:53:26 +00:00
JB
46babc89c8 Allow 0 byte usernames if correctly formed (#181)
* Allow 0 byte usernames if correctly formed

* Allow 0 byte usernames if correctly formed
2023-02-25 01:37:54 +00:00
JB
9b7a943888 Correctly identify and clean taken-over sessions (#180) 2023-02-25 01:24:17 +00:00
mochi-co
a909d30923 Small style fix 2023-02-22 23:33:49 +00:00
mochi-co
0851b09e4d Update server version 2023-02-22 23:06:22 +00:00
thedevop
a302c9dd88 Use *packets.Packet for outbound chan (#176)
* Use *packets.Packet for outbound chan

* Use *packets.Packet for outbound chan

* Use *packets.Packet for outbound chan
2023-02-22 23:02:44 +00:00
mochi-co
1e8f922102 update server version 2023-02-20 18:14:57 +00:00
Hubertus Hohl
4c16e5593f fix: correct decoding of packets including Properties exceeding 127 bytes in length (#172) 2023-02-20 18:14:19 +00:00
mochi-co
49cada4fbc Update server version 2023-02-10 23:39:27 +00:00
JB
ef34510c0b Expose dropped publish messages count in sys info (#170) 2023-02-10 23:38:20 +00:00
JB
e5716caad1 Fix potential NextPacketID endless loop, expand tests (#169)
* Fix possible NextPacketID endless loop, expand tests

* Optimize NextPacketID

* Use math constants
2023-02-10 23:27:21 +00:00
thedevop
4b039cb35c Add PublishDropped metrics (#167)
* Add PublishDropped

* Add PublishDropped

* Add PublishDropped

* Update storage_test.go

* Update system.go

* Update server.go
2023-02-10 14:44:01 +00:00
JB
aac245441a No longer issue retained messages on session takeover (#166) 2023-02-09 23:57:24 +00:00
JB
bb54cc68e6 Client write buffers (#165)
* Replace fanpool with client write buffers
2023-02-09 22:34:30 +00:00
thedevop
7ba1352a60 Add Clone to system.Info (#163)
* Add Clone using atomic operations

* Add Clone using atomic operations

* Use sysinfo.Clone

* Unit test for Clone

* Add Clone using atomic operations

* Update

* Update
2023-02-09 19:07:17 +00:00
mochi-co
ca849131eb Update server version 2023-02-05 11:07:07 +00:00
Wind
ba7e534122 failed to delete inflight data (#162)
The s.hooks.OnQosPublish method needs to be called, otherwise the following s.hooks.OnQosComplete or processPuback(s.hooks.OnQosComplete) method will report a data not found error.
2023-02-05 10:53:49 +00:00
mochi-co
db760c34a5 Update server version 2023-02-04 10:57:27 +00:00
JB
ae3ee81bb4 Rename Quota methods for clarity (#159) 2023-02-04 10:53:45 +00:00
JB
c2ca02d149 Move refreshDeadline to only trigger on successful transmission (#157) 2023-02-04 10:16:05 +00:00
Jeroen Rinzema
77a64d9c87 Include a listener accepting an existing net.Listener (#155) 2023-02-04 10:10:10 +00:00
Wind
8dec9cc962 invalid config type provided (#152)
* invalid config type provided

examples/persistence/bolt/main.go: invalid config type provided

* fixed ErrReceiveMaximum(receive maximum exceeded)

No quotas of the inflight is set in the readStore method, so each quota is equal to 0. The inheritClientSession method overrides the quotas of the new client inflight, so the processPublish method reports an ErrReceiveMaximum and disconnects the client.

* reset receive quota

receive quota should be reset across connections (as specified in the spec).
2023-02-04 10:06:26 +00:00
mochi-co
f90e52328d Update server version 2023-01-16 20:08:55 +00:00
JB
50aae47618 Publish retained messages only after connack (#147) 2023-01-16 19:50:01 +00:00
JB
0d79f2d63b Use Atomic instead of RWMutex for Hooks concurrency (#148)
* Use Atomic instead of RWMutex for Hooks concurrency
* Lock Hooks on Add Hook
2023-01-16 19:49:36 +00:00
JB
300152413c Ignore retain as published v3 (#142)
* Optimise Capabilities struct alignment

* Only use RetainAsPublished for v5 clients
2023-01-13 23:38:49 +00:00
mochi-co
0de1d731db Update version number 2023-01-10 00:01:21 +00:00
JB
80746abc52 Use correct connack return codes for MQTTv3 (#140) 2023-01-10 00:00:43 +00:00
mochi-co
a73cf4ca0e Update server version 2023-01-09 23:08:49 +00:00
mochi-co
bc549ee7ed Fix example imports 2023-01-09 22:52:24 +00:00
mochi-co
c464b46713 export client.Net.Conn for external use 2023-01-09 22:49:40 +00:00
mochi-co
05ce56008c Small code improvements 2023-01-09 22:49:20 +00:00
JB
8254cb0cbc Make hooks safe for concurrency (#139)
Co-authored-by: thedevop <60499013+thedevop@users.noreply.github.com>
2023-01-09 22:41:44 +00:00
mochi-co
4ae58b79e3 Update server version 2023-01-07 20:13:48 +00:00
thedevop
b895d688e0 Change inline check order (#133) 2023-01-07 20:02:05 +00:00
mochi-co
a600cd4ead fix grammar on Closed method doc 2023-01-07 17:57:04 +00:00
mochi-co
cdb44990cf Update version number 2023-01-07 17:30:58 +00:00
mochi-co
2d9c128111 Refactor stored subscription value assignments 2023-01-07 17:30:30 +00:00
mochi-co
a0d5bdb39f Fix Typos 2023-01-07 17:30:01 +00:00
Wind
4ebcef3cb6 Save subscription properties for mqttv5 (#131)
* Update redis.go

Save the subscription properties for mqqtv5

* Update badger.go

Save the subscription properties for mqqtv5

* Update bolt.go

Save the subscription properties for mqqtv5

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-01-07 17:24:23 +00:00
thedevop
fb8d4720d7 Add Client Closed (#130)
* Add Client Closed
* Add Client Closed
* Update clients_test.go
2023-01-07 17:14:51 +00:00
JB
4080c89127 Update README.md 2022-12-21 21:00:53 +00:00
JB
1b67e6f3f6 Update README.md 2022-12-21 20:58:25 +00:00
mochi-co
1adb02e087 Update readme and server version 2022-12-21 20:47:58 +00:00
JB
4d4140aa99 Connect ReturnResponseInfo only applies to Connack values (#128) 2022-12-21 20:37:08 +00:00
JB
e31840a37d Optimize inflight expiry (#127)
* Small formatting/style changes for filter existed

* Use OnQoSDropped hook instead of onInflightExpired
2022-12-21 19:44:25 +00:00
JB
7d2e16f2d3 Merge pull request #123 from wind-c/master
Variable existed in the method processSubscribe is unstable
2022-12-21 11:41:14 +00:00
JB
92cd935a16 Merge branch 'master' into master 2022-12-21 11:38:28 +00:00
JB
25ce27ce2d Merge pull request #124 from zgwit/master
Add unix socket listener
2022-12-21 11:28:23 +00:00
jason
527d084a4b Add unix socket listener 2022-12-20 23:02:59 +08:00
Wind
bb9f937bb0 Variable existed in the method processSubscribe is unstable
The variable existed can be changed repeatedly within a for loop. An array variable must be used to record the subscription of each filter.
2022-12-18 13:46:06 +08:00
Wind
511fe88684 Merge branch 'mochi-co:master' into master 2022-12-17 12:33:09 +08:00
JB
75504ff201 Update server version 2022-12-16 18:27:29 +00:00
Wind
a556feb325 Add the OnUnsubscribed hook to the unsubscribeClient method (#122)
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-16 18:23:58 +00:00
“Wind”
d06f47f4b9 Add the OnUnsubscribed hook to the unsubscribeClient method
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-17 00:40:06 +08:00
JB
8d4cc091b4 Update version number 2022-12-16 00:31:59 +00:00
JB
d8f28cb843 Enforce server max packet (#121)
* Enforce Server Maximum Packet Size on client read
* Fix tests
2022-12-16 00:30:23 +00:00
JB
88861c219d Merge pull request #116 from tommyminds/bugfix/ws_malformed_package
Fix websocket malformed packet bug
2022-12-15 18:21:53 +00:00
JB
7ba6cf28d9 Merge branch 'master' into bugfix/ws_malformed_package 2022-12-15 18:21:33 +00:00
JB
c174cfdc6b Merge pull request #119 from mochi-co/fix-on-published
Fix mis-typed onpublished hook, update version, fanpool defaults
2022-12-15 18:21:19 +00:00
mochi-co
4f198a99dd Fix mis-typed onpublished hook, update version, fanpool defaults 2022-12-15 18:19:02 +00:00
Tommy Maintz
2a9c9fcc40 Fix websocket malformed packet bug 2022-12-14 21:41:33 +01:00
JB
835a85c8bf Update README.md 2022-12-12 11:44:36 +00:00
mochi-co
fe5d9ffa61 Simplify Client construction, add NewClient method to Server, add Publish convenience method 2022-12-12 11:37:19 +00:00
mochi-co
aac186dcc1 Add newline for godoc formatting 2022-12-11 22:25:21 +00:00
JB
42931f332f Update badges to use v2 references 2022-12-11 21:44:44 +00:00
mochi-co
8a04648c09 Cleanup godoc formatting 2022-12-11 21:38:01 +00:00
JB
854c033fb6 Update README.md 2022-12-11 12:21:25 +00:00
mochi-co
74ed8cd046 Update go mod and imports to v2 2022-12-11 11:50:44 +00:00
JB
be164fa715 Update go.mod 2022-12-11 11:43:55 +00:00
JB
4287955161 Update go.mod 2022-12-11 11:42:43 +00:00
JB
bbf08ff496 Update README.md 2022-12-10 22:51:01 +00:00
JB
c38201ff8b Update README.md 2022-12-10 22:32:53 +00:00
JB
f8b4ff5c0d Update README.md 2022-12-10 22:29:11 +00:00
JB
661e23e051 Update README.md 2022-12-10 22:11:39 +00:00
322 changed files with 27002 additions and 12790 deletions

View File

@@ -7,37 +7,39 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.21
- name: Vet
run: go vet ./...
- name: Test
run: go test -race ./... && echo true
coverage:
name: Test with Coverage
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.19.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.21'
- name: Check out code
uses: actions/checkout@v3
- name: Install dependencies
run: |
go mod download
- name: Run Unit tests
run: |
go test -race -covermode atomic -coverprofile=covprofile ./...
- name: Install goveralls
run: go install github.com/mattn/goveralls@latest
- name: Send coverage
env:
COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: goveralls -coverprofile=covprofile -service=github

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
cmd/mqtt
.DS_Store
*.db
.idea

View File

@@ -1,4 +1,4 @@
FROM golang:1.19.0-alpine3.15 AS builder
FROM golang:1.21.0-alpine3.18 AS builder
RUN apk update
RUN apk add git

View File

@@ -1,7 +1,8 @@
The MIT License (MIT)
Copyright (c) 2019, 2022 Jonathan Blake (mochi-co)
Copyright (c) 2023 Mochi-MQTT Organisation
Copyright (c) 2019, 2022, 2023 Jonathan Blake (mochi-co)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

293
README.md
View File

@@ -1,26 +1,26 @@
# Mochi-MQTT Server
<p align="center">
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
![build status](https://github.com/mochi-mqtt/server/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-mqtt/server/badge.svg?branch=master&v2)](https://coveralls.io/github/mochi-mqtt/server?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-mqtt/server)](https://goreportcard.com/report/github.com/mochi-mqtt/server/v2)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-mqtt/server.svg)](https://pkg.go.dev/github.com/mochi-mqtt/server/v2)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-mqtt/server/issues)
</p>
# Mochi MQTT Broker
## The fully compliant, embeddable high-performance Go MQTT v5 server (v5 | v3.1.1 | v3.0)
🎆 **mochi-co/mqtt is now part of the new mochi-mqtt organisation.** [Read about this announcement here.](https://github.com/orgs/mochi-mqtt/discussions/271)
### Mochi-MQTT is a fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker/server
> #### 📦 💬 See Github Discussions for discussions about releases
> Ongoing discussion about current and future releases can be found at https://github.com/mochi-co/mqtt/discussions
Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) MQTT v5 broker server written in Go, designed for the development of telemetry and internet-of-things projects. The server can be used either as a standalone binary or embedded as a library in your own applications, and has been designed to be as lightweight and fast as possible, with great care taken to ensure the quality and maintainability of the project.
### What is MQTT?
#### What is MQTT?
MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol.
## What's new in Version 2.0.0?
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class citizen.
#### Mochi-MQTT Features
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
- User and MQTTv5 Packet Properties
@@ -35,26 +35,27 @@ Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learn
- Plus all the original MQTT features of Mochi MQTT v1, such as Full QoS(0,1,2), $SYS topics, retained messages, etc.
- Developer-centric:
- Most core broker code is now exported and accessible, for total developer control.
- Full featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
- Full-featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
- Direct Packet Injection using special inline client, or masquerade as existing clients.
- Performant and Stable:
- Our classic trie-based Topic-Subscription model.
- A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
- Over a thousand carefully considered unit test scenarios.
- TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners.
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
> There is no upgrade path from v1.0.0. Please review the documentation and this readme to get a sense of the changes required (e.g. the v1 events system, auth, and persistence have all been replaced with the new hooks system).
### Compatibility Notes
Because of the overlap between the v5 specification and previous versions of mqtt, the server can accept both v5 and v3 clients, but note that in cases where both v5 an v3 clients are connected, properties and features provided for v5 clients will be downgraded for v3 clients (such as user properties).
Support for MQTT v3.0.0 and v3.1.1 is considered hybrid-compatibility. Where not specifically restricted in the v3 specification, more modern and safety-first v5 behaviours are used instead - such as expiry for inflight and retained messages, and clients - and quality-of-service flow control limits.
#### When is this repo updated?
Unless it's a critical issue, new releases typically go out over the weekend.
## Roadmap
- Please [open an issue](https://github.com/mochi-co/mqtt/issues) to request new features or event hooks!
- Please [open an issue](https://github.com/mochi-mqtt/server/issues) to request new features or event hooks!
- Cluster support.
- Enhanced Metrics support.
- File-based server configuration (supporting docker).
@@ -75,33 +76,55 @@ A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Webso
docker build -t mochi:latest .
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
```
_More substantial docker support is being discussed [here](https://github.com/orgs/mochi-mqtt/discussions/281#discussion-5544545) and [here](https://github.com/orgs/mochi-mqtt/discussions/209). Please join the discussion if you use Mochi-MQTT in this environment._
## Developing with Mochi MQTT
### Importing as a package
Importing Mochi MQTT as a package requires just a few lines of code to get started.
``` go
import (
"github.com/mochi-co/mqtt"
"log"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
// Create signals channel to run server until interrupted
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
// Create the new MQTT Server.
server := mqtt.New(nil)
// Allow all connections.
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.Serve()
// Allow all connections.
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
// Run server until interrupted
<-done
// Cleanup
}
```
@@ -110,10 +133,16 @@ Examples of running the broker with various configurations can be found in the [
#### Network Listeners
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
- `listeners.NewTCP(...)` - A TCP listener.
- `listeners.NewWebsocket(...)` A Websocket listener.
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
| Listener | Usage |
|------------------------------|----------------------------------------------------------------------------------------------|
| listeners.NewTCP | A TCP listener |
| listeners.NewUnixSock | A Unix Socket listener |
| listeners.NewNet | A net.Listener listener |
| listeners.NewWebsocket | A Websocket listener |
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
| listeners.NewHTTPHealthCheck | An HTTP healthcheck listener to provide health check responses for e.g. cloud infrastructure |
> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
A `*listeners.Config` may be passed to configure TLS.
@@ -130,11 +159,14 @@ server := mqtt.New(&mqtt.Options{
ObscureNotAuthorized: true,
},
},
ClientNetWriteBufferSize: 4096,
ClientNetReadBufferSize: 4096,
SysTopicResendInterval: 10,
InlineClient: false,
})
```
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options.
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
## Event Hooks
@@ -142,16 +174,16 @@ A universal event hooks system allows developers to hook into various parts of t
Hooks are stackable - you can add multiple hooks to a server, and they will be run in the order they were added. Some hooks modify values, and these modified values will be passed to the subsequent hooks before being returned to the runtime code.
| Type | Import | Info |
| -- | -- | -- |
| Access Control | [mochi-co/mqtt/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to all topics. |
| Access Control | [mochi-co/mqtt/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
| Persistence | [mochi-co/mqtt/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
| Persistence | [mochi-co/mqtt/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
| Persistence | [mochi-co/mqtt/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
| Debugging | [mochi-co/mqtt/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |
| Type | Import | Info |
|----------------|--------------------------------------------------------------------------|----------------------------------------------------------------------------|
| Access Control | [mochi-mqtt/server/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to all topics. |
| Access Control | [mochi-mqtt/server/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
| Persistence | [mochi-mqtt/server/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
| Persistence | [mochi-mqtt/server/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
| Persistence | [mochi-mqtt/server/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
| Debugging | [mochi-mqtt/server/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |
Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-co/mqtt/issues) and let everyone know!
Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-mqtt/server/issues) and let everyone know!
### Access Control
#### Allow Hook
@@ -219,7 +251,7 @@ err := server.AddHook(new(auth.Hook), &auth.Options{
The ledger can also be stored as JSON or YAML and loaded using the Data field:
```go
err = server.AddHook(new(auth.Hook), &auth.Options{
err := server.AddHook(new(auth.Hook), &auth.Options{
Data: data, // build ledger from byte slice: yaml or json
})
```
@@ -256,60 +288,99 @@ For more information on how the badger hook works, or how to use it, see the [ex
There is also a BoltDB hook which has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go).
## Developing with Event Hooks
Many hooks are available for interacting with the broker and client lifecycle.
The function signatures for all the hooks and `mqtt.Hook` interface can be found in [hooks.go](hooks.go).
> The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks be used to control and modify all incoming and outgoing packets.
| Function | Usage |
| -------------------------- | -- |
| OnStarted | Called when the server has successfully started.|
| OnStopped | Called when the server has successfully stopped. |
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
| OnSysInfoTick | Called when the $SYS topic values are published out. |
| OnConnect | Called when a new client connects |
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
| OnPacketSent | Called when a packet has been sent to a client. |
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.|
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnRetainMessage | Called then a published message is retained. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
| Function | Usage |
|------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| OnStarted | Called when the server has successfully started. |
| OnStopped | Called when the server has successfully stopped. |
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
| OnSysInfoTick | Called when the $SYS topic values are published out. |
| OnConnect | Called when a new client connects, may return an error or packet code to halt the client connection process. |
| OnSessionEstablish | Called immediately after a new client connects and authenticates and immediately before the session is established and CONNACK is sent. |
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
| OnPacketSent | Called when a packet has been sent to a client. |
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification. |
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
| OnRetainMessage | Called then a published message is retained. |
| OnRetainPublished | Called then a retained message is published to a client. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnPacketIDExhausted | Called when a client runs out of unused packet ids to assign. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
### Packet Injection
It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
### Inline Client (v2.4.0+)
It's now possible to subscribe and publish to topics directly from the embedding code, by using the `inline client` feature. The Inline Client is an embedded client which operates as part of the server, and can be enabled in the server options:
```go
server := mqtt.New(&mqtt.Options{
InlineClient: true,
})
```
Once enabled, you will be able to use the `server.Publish`, `server.Subscribe`, and `server.Unsubscribe` methods to issue and received messages from broker-adjacent code.
Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement).
> See [direct examples](examples/direct/main.go) for real-life usage examples.
#### Inline Publish
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
```go
cl := mqtt.NewInlineClient("inline", "local")
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
```
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
#### Inline Subscribe
To subscribe to a topic filter from within the embedding application, you can use the `server.Subscribe(filter string, subscriptionId int, handler InlineSubFn) error` method with a callback function. Note that only QoS 0 is supported for inline subscriptions. If you wish to have multiple callbacks for the same filter, you can use the MQTTv5 `subscriptionId` property to differentiate.
```go
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
}
server.Subscribe("direct/#", 1, callbackFn)
```
#### Inline Unsubscribe
You may wish to unsubscribe if you have subscribed to a filter using the inline client. You can do this easily with the `server.Unsubscribe(filter string, subscriptionId int) error` method:
```go
server.Unsubscribe("direct/#", 1)
```
### Packet Injection
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client.
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
Most of the time you'll want to use the Inline Client described above, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics. In this case, you can create an inline client from scratch which will behave the same as the built-in inline client.
```go
cl := server.NewClient(nil, "local", "inline", true)
server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
@@ -323,6 +394,7 @@ server.InjectPacket(cl, packets.Packet{
See the [hooks example](examples/hooks/main.go) to see this feature in action.
### Testing
#### Unit Tests
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
@@ -337,44 +409,59 @@ You can check the broker against the [Paho Interoperability Test](https://github
## Performance Benchmarks
Mochi MQTT performance is comparable with popular brokers such as Mosquitto, Mosca, and VerneMQ.
Mochi MQTT performance is comparable with popular brokers such as Mosquitto, EMQX, and others.
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a Apple Macbook Air M2, using `cmd/main.go` default settings. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better.
> The values presented in the benchmark are not representative of true messages per second throughput. They rely on an unusual calculation by mqtt-stresser, but are usable as they are consistent across all brokers.
> Benchmarks are provided as a general performance expectation guideline only.
> Benchmarks are provided as a general performance expectation guideline only. Comparisons are performed using out-of-the-box default configurations.
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 |
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 |
| Mochi v2.2.10 | 124,772 | 125,456 | 124,614 | 314,461 | 313,186 | 311,910 |
| [Mosquitto v2.0.15](https://github.com/eclipse/mosquitto) | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| [EMQX v5.0.11](https://github.com/emqx/emqx) | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
| [Rumqtt v0.21.0](https://github.com/bytebeamio/rumqtt) | 112,208 | 108,480 | 104,753 | 135,784 | 126,446 | 117,108 |
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 |
| Mochi v2.2.10 | 41,825 | 31,663| 23,008 | 144,058 | 65,903 | 37,618 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
| Rumqtt v0.21.0 | 42,213 | 23,153 | 20,814 | 49,465 | 36,626 | 19,283 |
Million Message Challenge (hit the server with 1 million messages immediately):
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636 |
| Mochi v2.2.10 | 13,532 | 4,425 | 2,344 | 52,120 | 7,274 | 2,701 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
| Rumqtt v0.21.0 | 78,972 | 5,047 | 3,804 | 4,286 | 3,249 | 2,027 |
> Not sure what's going on with EMQX here, perhaps the docker out-of-the-box settings are not optimal, so take it with a pinch of salt as we know for a fact it's a solid piece of software.
## Contribution Guidelines
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-mqtt/server/issues) to report a bug, ask a question, or make a feature request. If you open a pull request, please try to follow the following guidelines:
- Try to maintain test coverage where reasonably possible.
- Clearly state what the PR does and why.
- Please remember to add your SPDX FileContributor tag to files where you have made a meaningful contribution.
[SPDX Annotations](https://spdx.dev) are used to clearly indicate the license, copyright, and contributions of each file in a machine-readable format. If you are adding a new file to the repository, please ensure it has the following SPDX header:
```go
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt
// SPDX-FileContributor: Your name or alias <optional@email.address>
package name
```
Please ensure to add a new `SPDX-FileContributor` line for each contributor to the file. Refer to other files for examples. Please remember to do this, your contributions to this project are valuable and appreciated - it's important to receive credit!
## Stargazers over time 🥰
[![Stargazers over time](https://starchart.cc/mochi-co/mqtt.svg)](https://starchart.cc/mochi-co/mqtt)
Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-co/mqtt/issues)
## Contributions
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request.
[![Stargazers over time](https://starchart.cc/mochi-mqtt/server.svg)](https://starchart.cc/mochi-mqtt/server)
Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-mqtt/server/issues)

View File

@@ -1,11 +1,13 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
@@ -15,7 +17,7 @@ import (
"github.com/rs/xid"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-mqtt/server/v2/packets"
)
const (
@@ -86,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
if client.Net.Listener == id && !client.Closed() {
clients = append(clients, client)
}
}
@@ -97,7 +99,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
type Client struct {
Properties ClientProperties // client properties
State ClientState // the operational state of the client.
Net ClientConnection // network connection state of the clinet
Net ClientConnection // network connection state of the client
ID string // the client id.
ops *ops // ops provides a reference to server ops.
sync.RWMutex // mutex
@@ -105,11 +107,11 @@ type Client struct {
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
conn net.Conn // the net.Conn used to establish the connection
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
Inline bool // client is an inline programmetic client
Inline bool // if true, the client is the built-in 'inline' embedded client
}
// ClientProperties contains the properties which define the client behaviour.
@@ -132,32 +134,37 @@ type Will struct {
Retain bool // -
}
// State tracks the state of the client.
// ClientState tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
endOnce sync.Once // only end once
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
keepalive uint16 // the number of seconds the connection can wait
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
cancelOpen context.CancelFunc // cancel function for open context
outboundQty int32 // number of messages currently in the outbound queue
Keepalive uint16 // the number of seconds the connection can wait
ServerKeepalive bool // keepalive was set by the server
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, o *ops) *Client {
// newClient returns a new instance of Client. This is almost exclusively used by Server
// for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
ctx, cancel := context.WithCancel(context.Background())
cl := &Client{
Net: ClientConnection{
conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Remote: c.RemoteAddr().String(),
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
open: ctx,
cancelOpen: cancel,
Keepalive: defaultKeepalive,
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
@@ -165,43 +172,33 @@ func NewClient(c net.Conn, o *ops) *Client {
ops: o,
}
cl.refreshDeadline(cl.State.keepalive)
if c != nil {
cl.Net = ClientConnection{
Conn: c,
bconn: bufio.NewReadWriter(
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
bufio.NewWriterSize(c, o.options.ClientNetWriteBufferSize),
),
Remote: c.RemoteAddr().String(),
}
}
return cl
}
// NewInlineClient returns a client used when publishing from the embedding system.
func NewInlineClient(id, remote string) *Client {
return &Client{
ID: id,
Net: ClientConnection{
Remote: remote,
Inline: true,
},
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
}
}
// newClientStub returns an instance of Client with minimal initializations, such as
// restoring client data from a db. In particular, the client is marked as offline (done).
func newClientStub() *Client {
return &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(0),
done: 1,
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for {
select {
case pk := <-cl.State.outbound:
if err := cl.WritePacket(*pk); err != nil {
// TODO : Figure out what to do with error
cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
}
atomic.AddInt32(&cl.State.outboundQty, -1)
case <-cl.State.open.Done():
return
}
}
}
@@ -214,9 +211,9 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
cl.ID = pk.Connect.ClientIdentifier
@@ -225,11 +222,6 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Props.AssignedClientID = cl.ID
}
cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}
if pk.Connect.WillFlag {
cl.Properties.Will = Will{
Qos: pk.Connect.WillQos,
@@ -247,19 +239,17 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Will.Flag = 1 // atomic for checking
}
}
cl.refreshDeadline(cl.State.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.Net.conn != nil {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
if cl.Net.Conn != nil {
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
@@ -267,28 +257,30 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i + 1
started := i
overflowed := false
for {
if i >= 65535 {
overflowed = true
i = 1
} else {
i++
}
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
if i >= cl.ops.options.Capabilities.maximumPacketID {
overflowed = true
i = 0
continue
}
i++
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
break
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
}
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
@@ -302,7 +294,7 @@ func (cl *Client) ResendInflightMessages(force bool) error {
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
}
// cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends)
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
err := cl.WritePacket(tk)
if err != nil {
return err
@@ -319,18 +311,19 @@ func (cl *Client) ResendInflightMessages(force bool) error {
return nil
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
var deleted int64
// ClearInflights deletes all inflight messages for the client, e.g. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted++
deleted = append(deleted, tk.PacketID)
}
}
}
return deleted
}
@@ -340,11 +333,11 @@ func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return nil
}
cl.refreshDeadline(cl.State.keepalive)
cl.refreshDeadline(cl.State.Keepalive)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
@@ -365,20 +358,20 @@ func (cl *Client) Read(packetHandler ReadFn) error {
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.done) == 1 {
return
}
cl.State.endOnce.Do(func() {
if cl.Net.conn != nil {
_ = cl.Net.conn.Close() // omit close error
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
if err != nil {
cl.State.stopCause.Store(err)
}
atomic.StoreUint32(&cl.State.done, 1)
if cl.State.cancelOpen != nil {
cl.State.cancelOpen()
}
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
@@ -391,6 +384,11 @@ func (cl *Client) StopCause() error {
return cl.State.stopCause.Load().(error)
}
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return cl.State.open == nil || cl.State.open.Err() != nil
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
@@ -413,6 +411,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
return err
}
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil
}
@@ -480,15 +482,14 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) error {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return ErrConnectionClosed
}
if cl.Net.conn == nil {
if cl.Net.Conn == nil {
return nil
}
defer cl.refreshDeadline(cl.State.keepalive)
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
@@ -502,8 +503,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
@@ -553,7 +554,11 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.conn)
n, err := func() (int64, error) {
cl.Lock()
defer cl.Unlock()
return nb.WriteTo(cl.Net.Conn)
}()
if err != nil {
return err
}

View File

@@ -1,9 +1,11 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"context"
"errors"
"io"
"net"
@@ -11,8 +13,8 @@ import (
"testing"
"time"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
"github.com/stretchr/testify/require"
)
@@ -21,16 +23,20 @@ const pkInfo = "packet type %v, %s"
var errClientStop = errors.New("test stop")
func newClient() (cl *Client, r net.Conn, w net.Conn) {
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = NewClient(w, &ops{
cl = newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
log: logger,
options: &Options{
Capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
},
})
@@ -41,6 +47,9 @@ func newClient() (cl *Client, r net.Conn, w net.Conn) {
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
go cl.WriteLoop()
return
}
@@ -106,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
@@ -118,34 +127,23 @@ func TestClientsGetByListener(t *testing.T) {
}
func TestNewClient(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Nil(t, cl.StopCause())
}
func TestNewClientStub(t *testing.T) {
cl := newClientStub()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
}
func TestNewInlineClient(t *testing.T) {
cl := NewInlineClient("inline", "local")
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done))
require.Equal(t, "inline", cl.ID)
require.Equal(t, "local", cl.Net.Remote)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.NotNil(t, cl.ops.options.Capabilities)
require.False(t, cl.Net.Inline)
}
func TestClientParseConnect(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
@@ -167,7 +165,7 @@ func TestClientParseConnect(t *testing.T) {
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
@@ -175,14 +173,14 @@ func TestClientParseConnect(t *testing.T) {
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
}
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
@@ -207,13 +205,13 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) {
}
func TestClientParseConnectNoID(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.ParseConnect("tcp1", packets.Packet{})
require.NotEmpty(t, cl.ID)
}
func TestClientNextPacketID(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
i, err := cl.NextPacketID()
require.NoError(t, err)
@@ -225,7 +223,7 @@ func TestClientNextPacketID(t *testing.T) {
}
func TestClientNextPacketIDInUse(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
// skip over 2
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
@@ -248,33 +246,37 @@ func TestClientNextPacketIDInUse(t *testing.T) {
}
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newClient()
for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
cl, _, _ := newTestClient()
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
}
i, err := cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
require.Equal(t, uint32(0), i)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID - 1
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
_, err = cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newClient()
cl.State.packetID = uint32(65534)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(65535), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
}
func TestClientClearInflights(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
@@ -284,13 +286,15 @@ func TestClientClearInflights(t *testing.T) {
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
cl.ClearInflights(n, 4)
deleted := cl.ClearInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
require.Equal(t, 2, cl.State.Inflight.Len())
}
func TestClientResendInflightMessages(t *testing.T) {
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
cl, r, w := newClient()
cl, r, w := newTestClient()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
@@ -299,7 +303,7 @@ func TestClientResendInflightMessages(t *testing.T) {
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
w.Close()
_ = w.Close()
}()
buf, err := io.ReadAll(r)
@@ -310,8 +314,8 @@ func TestClientResendInflightMessages(t *testing.T) {
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newClient()
r.Close()
cl, r, _ := newTestClient()
_ = r.Close()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
@@ -322,24 +326,24 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
}
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
}
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0x00})
r.Close()
_, _ = r.Write([]byte{packets.Connect << 4, 0x00})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -349,12 +353,12 @@ func TestClientReadFixedHeader(t *testing.T) {
}
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
r.Close()
_, _ = r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -362,12 +366,28 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
require.Error(t, err)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newClient()
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.options.Capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
r.Close()
_, _ = r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
_ = r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -377,12 +397,12 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
}
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
r.Close()
_, _ = r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -391,10 +411,10 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
}
func TestClientReadOK(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
packets.Publish << 4, 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
@@ -404,7 +424,7 @@ func TestClientReadOK(t *testing.T) {
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
_ = r.Close()
}()
var pks []packets.Packet
@@ -445,9 +465,9 @@ func TestClientReadOK(t *testing.T) {
}
func TestClientReadDone(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
cl.State.cancelOpen()
o := make(chan error)
go func() {
@@ -460,21 +480,29 @@ func TestClientReadDone(t *testing.T) {
}
func TestClientStop(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.True(t, cl.Closed())
require.Equal(t, nil, cl.StopCause())
}
func TestClientClosed(t *testing.T) {
cl, _, _ := newTestClient()
require.False(t, cl.Closed())
cl.Stop(nil)
require.True(t, cl.Closed())
}
func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
})
r.Close()
_ = r.Close()
}()
cl.Net.bconn = nil
@@ -485,16 +513,16 @@ func TestClientReadFixedHeaderError(t *testing.T) {
}
func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
_ = r.Close()
}()
err := cl.Read(func(cl *Client, pk packets.Packet) error {
@@ -505,16 +533,16 @@ func TestClientReadReadHandlerErr(t *testing.T) {
}
func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
_ = r.Close()
}()
fh := new(packets.FixedHeader)
@@ -537,7 +565,7 @@ func TestClientReadReadPacketOK(t *testing.T) {
}
func TestClientReadPacket(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
for _, tx := range pkTable {
@@ -545,7 +573,7 @@ func TestClientReadPacket(t *testing.T) {
t.Run(tt.Desc, func(t *testing.T) {
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
go func() {
r.Write(tt.RawBytes)
_, _ = r.Write(tt.RawBytes)
}()
fh := new(packets.FixedHeader)
@@ -570,9 +598,17 @@ func TestClientReadPacket(t *testing.T) {
}
}
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
_ = cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
}
func TestClientWritePacket(t *testing.T) {
for _, tt := range pkTable {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
@@ -588,7 +624,7 @@ func TestClientWritePacket(t *testing.T) {
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.conn.Close()
_ = cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
@@ -612,7 +648,7 @@ func TestClientWritePacket(t *testing.T) {
}
func TestWriteClientOversizePacket(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Properties.Props.MaximumPacketSize = 2
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
err := cl.WritePacket(pk)
@@ -621,16 +657,16 @@ func TestWriteClientOversizePacket(t *testing.T) {
}
func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
_ = r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
@@ -641,16 +677,16 @@ func TestClientReadPacketReadingError(t *testing.T) {
}
func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newClient()
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
_, _ = r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
_ = r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
@@ -660,7 +696,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
}
func TestClientWritePacketWriteNoConn(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.Stop(errClientStop)
err := cl.WritePacket(*pkTable[1].Packet)
@@ -669,15 +705,15 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
}
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newClient()
cl.Net.conn.Close()
cl, _, _ := newTestClient()
_ = cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
}
func TestClientWritePacketInvalidPacket(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
err := cl.WritePacket(packets.Packet{})
require.Error(t, err)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -10,9 +11,9 @@ import (
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
@@ -58,7 +59,8 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -9,9 +10,9 @@ import (
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
@@ -76,7 +77,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -10,9 +11,9 @@ import (
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
@@ -58,7 +59,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,19 +1,20 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"log/slog"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/hooks/debug"
"github.com/mochi-co/mqtt/listeners"
"github.com/rs/zerolog"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/debug"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
@@ -26,17 +27,21 @@ func main() {
}()
server := mqtt.New(nil)
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(auth.AllowHook), nil)
level := new(slog.LevelVar)
server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
}))
level.Set(slog.LevelDebug)
err := server.AddHook(new(debug.Hook), &debug.Options{
// ShowPacketData: true,
})
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(debug.Hook), &debug.Options{
ShowPacketData: true,
})
err = server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
@@ -55,7 +60,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

83
examples/direct/main.go Normal file
View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-mqtt/server/v2/hooks/auth"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(&mqtt.Options{
InlineClient: true, // you must enable inline client to use direct publishing and subscribing.
})
_ = server.AddHook(new(auth.AllowHook), nil)
// Start the server
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
// Demonstration of using an inline client to directly subscribe to a topic and receive a message when
// that subscription is activated. The inline subscription method uses the same internal subscription logic
// as used for external (normal) clients.
go func() {
// Inline subscriptions can also receive retained messages on subscription.
_ = server.Publish("direct/retained", []byte("retained message"), true, 0)
_ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0)
// Subscribe to a filter and handle any received messages via a callback function.
callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) {
server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload))
}
server.Log.Info("inline client subscribing")
_ = server.Subscribe("direct/#", 1, callbackFn)
_ = server.Subscribe("direct/#", 2, callbackFn)
}()
// There is a shorthand convenience function, Publish, for easily sending publish packets if you are not
// concerned with creating your own packets. If you want to have more control over your packets, you can
//directly inject a packet of any kind into the broker. See examples/hooks/main.go for usage.
go func() {
for range time.Tick(time.Second * 3) {
err := server.Publish("direct/publish", []byte("scheduled message"), false, 0)
if err != nil {
server.Log.Error("server.Publish", "error", err)
}
server.Log.Info("main.go issued direct message to direct/publish")
}
}()
go func() {
time.Sleep(time.Second * 10)
// Unsubscribe from the same filter to stop receiving messages.
server.Log.Info("inline client unsubscribing")
_ = server.Unsubscribe("direct/#", 1)
}()
<-done
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,20 +1,22 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"bytes"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/listeners"
"github.com/mochi-co/mqtt/packets"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
"github.com/mochi-mqtt/server/v2/packets"
)
func main() {
@@ -51,23 +53,38 @@ func main() {
// `server.Publish` method. Subscribe to `direct/publish` using your
// MQTT client to see the messages.
go func() {
cl := mqtt.NewInlineClient("inline", "local")
for range time.Tick(time.Second * 10) {
server.InjectPacket(cl, packets.Packet{
cl := server.NewClient(nil, "local", "inline", true)
for range time.Tick(time.Second * 1) {
err := server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "direct/publish",
Payload: []byte("scheduled message"),
Payload: []byte("injected scheduled message"),
})
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
if err != nil {
server.Log.Error("server.InjectPacket", "error", err)
}
server.Log.Info("main.go injected packet to direct/publish")
}
}()
// There is also a shorthand convenience function, Publish, for easily sending
// publish packets if you are not concerned with creating your own packets.
go func() {
for range time.Tick(time.Second * 5) {
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
if err != nil {
server.Log.Error("server.Publish", "error", err)
}
server.Log.Info("main.go issued direct message to direct/publish")
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}
type ExampleHook struct {
@@ -90,38 +107,44 @@ func (h *ExampleHook) Provides(b byte) bool {
}
func (h *ExampleHook) Init(config any) error {
h.Log.Info().Msg("initialised")
h.Log.Info("initialised")
return nil
}
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
h.Log.Info("client connected", "client", cl.ID)
return nil
}
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected")
if err != nil {
h.Log.Info("client disconnected", "client", cl.ID, "expire", expire, "error", err)
} else {
h.Log.Info("client disconnected", "client", cl.ID, "expire", expire)
}
}
func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes)
h.Log.Info(fmt.Sprintf("subscribed qos=%v", reasonCodes), "client", cl.ID, "filters", pk.Filters)
}
func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed")
h.Log.Info("unsubscribed", "client", cl.ID, "filters", pk.Filters)
}
func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client")
h.Log.Info("received from client", "client", cl.ID, "payload", string(pk.Payload))
pkx := pk
if string(pk.Payload) == "hello" {
pkx.Payload = []byte("hello world")
h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client")
h.Log.Info("received modified packet from client", "client", cl.ID, "payload", string(pkx.Payload))
}
return pkx, nil
}
func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client")
h.Log.Info("published to client", "client", cl.ID, "payload", string(pk.Payload))
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -10,9 +11,9 @@ import (
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/listeners"
"github.com/mochi-co/mqtt/packets"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/listeners"
"github.com/mochi-mqtt/server/v2/packets"
)
func main() {
@@ -25,10 +26,9 @@ func main() {
}()
server := mqtt.New(nil)
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
_ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
@@ -45,9 +45,9 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}
type pahoAuthHook struct {
@@ -61,6 +61,7 @@ func (h *pahoAuthHook) ID() string {
func (h *pahoAuthHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnConnect,
mqtt.OnACLCheck,
}, []byte{b})
}
@@ -72,3 +73,12 @@ func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet)
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return topic != "test/nosubscribe"
}
func (h *pahoAuthHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
// Handle paho test_server_keep_alive
if pk.Connect.Keepalive == 120 && pk.Connect.Clean {
cl.State.Keepalive = 60
cl.State.ServerKeepalive = true
}
return nil
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -9,10 +10,10 @@ import (
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/hooks/storage/badger"
"github.com/mochi-co/mqtt/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/storage/badger"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
@@ -51,8 +52,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -10,10 +11,10 @@ import (
"syscall"
"time"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/hooks/storage/bolt"
"github.com/mochi-co/mqtt/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/storage/bolt"
"github.com/mochi-mqtt/server/v2/listeners"
"go.etcd.io/bbolt"
)
@@ -29,12 +30,15 @@ func main() {
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(bolt.Hook), bolt.Options{
err := server.AddHook(new(bolt.Hook), &bolt.Options{
Path: "bolt.db",
Options: &bbolt.Options{
Timeout: 500 * time.Millisecond,
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
@@ -50,7 +54,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,19 +1,20 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"log/slog"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/hooks/storage/redis"
"github.com/mochi-co/mqtt/listeners"
"github.com/rs/zerolog"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/storage/redis"
"github.com/mochi-mqtt/server/v2/listeners"
rv8 "github.com/go-redis/redis/v8"
)
@@ -29,8 +30,12 @@ func main() {
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
level := new(slog.LevelVar)
server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
}))
level.Set(slog.LevelDebug)
err := server.AddHook(new(redis.Hook), &redis.Options{
Options: &rv8.Options{
@@ -57,8 +62,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -9,9 +10,9 @@ import (
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
@@ -51,7 +52,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -10,9 +11,9 @@ import (
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
)
var (
@@ -96,7 +97,7 @@ func main() {
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{
TLSConfig: tlsConfig,
}, nil)
}, server.Info)
err = server.AddListener(stats)
if err != nil {
log.Fatal(err)
@@ -110,7 +111,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
@@ -9,9 +10,9 @@ import (
"os/signal"
"syscall"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/auth"
"github.com/mochi-co/mqtt/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/listeners"
)
func main() {
@@ -40,7 +41,7 @@ func main() {
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -1,100 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
package mqtt
import (
"sync"
"sync/atomic"
xh "github.com/cespare/xxhash/v2"
)
// taskChan is a channel for incoming task functions.
type taskChan chan func()
// FanPool is a fixed-sized fan-style worker pool with multiple
// working 'columns'. Instead of a single queue channel processed by
// many goroutines, this fan pool uses many queue channels each
// processed by a single goroutine.
// Very special thanks are given to the authors of HMQ in particular
// @chowyu08 and @muXxer for their work on the fixpool worker pool
// https://github.com/fhmq/hmq/blob/master/pool/fixpool.go
// from which this fan-pool is heavily inspired.
type FanPool struct {
queue []taskChan
wg sync.WaitGroup
capacity uint64
perChan uint64
Mutex sync.Mutex
}
// New returns a new instance of FanPool. fanSize controls the number of 'columns'
// of the fan, whereas queueSize controls the size of each column's queue.
func NewFanPool(fanSize, queueSize uint64) *FanPool {
pool := &FanPool{
capacity: fanSize,
perChan: queueSize,
queue: make([]taskChan, fanSize),
}
pool.fillWorkers(fanSize)
return pool
}
// fillWorkers adds columns to the fan pool with an associated worker goroutine.
func (p *FanPool) fillWorkers(n uint64) {
for i := uint64(0); i < n; i++ {
p.queue[i] = make(taskChan, p.perChan)
go p.worker(p.queue[i])
p.wg.Add(1)
}
}
// worker is a worker goroutine which processes tasks from a single queue.
func (p *FanPool) worker(ch taskChan) {
defer p.wg.Done()
var task func()
var ok bool
for {
task, ok = <-ch
if !ok {
return
}
task()
}
}
// Enqueue adds a new task to the queue to be processed.
func (p *FanPool) Enqueue(id string, task func()) {
if p.Size() == 0 {
return
}
// We can use xh.Sum64 to get a specific queue index
// which remains the same for a client id, giving each
// client their own queue.
p.queue[xh.Sum64([]byte(id))%p.Size()] <- task
}
// Wait blocks until all the workers in the pool have completed.
func (p *FanPool) Wait() {
p.wg.Wait()
}
// Close issues a shutdown signal to the workers.
func (p *FanPool) Close() {
for i := 0; i < int(p.Size()); i++ {
if p.queue[i] != nil {
close(p.queue[i])
}
}
p.queue = nil
atomic.StoreUint64(&p.capacity, 0)
}
// Size returns the current number of workers in the pool.
func (p *FanPool) Size() uint64 {
return atomic.LoadUint64(&p.capacity)
}

View File

@@ -1,88 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestFanPool(t *testing.T) {
f := NewFanPool(1, 2)
require.NotNil(t, f)
require.Equal(t, uint64(1), f.capacity)
require.Equal(t, 2, cap(f.queue[0]))
o := make(chan bool)
go func() {
f.Enqueue("test", func() {
o <- true
})
}()
require.True(t, <-o)
f.Close()
f.Wait()
}
func TestFillWorkers(t *testing.T) {
f := &FanPool{
perChan: 3,
queue: make([]taskChan, 2),
}
f.fillWorkers(2)
require.Len(t, f.queue, 2)
require.Equal(t, 3, cap(f.queue[0]))
}
func TestEnqueue(t *testing.T) {
f := &FanPool{
capacity: 2,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
},
}
go func() {
f.Enqueue("a", func() {})
}()
require.NotNil(t, <-f.queue[1])
}
func TestEnqueueOnEmpty(t *testing.T) {
f := &FanPool{
queue: []taskChan{},
}
go func() {
f.Enqueue("a", func() {})
}()
require.Len(t, f.queue, 0)
}
func TestSize(t *testing.T) {
f := &FanPool{
capacity: 10,
}
require.Equal(t, uint64(10), f.Size())
}
func TestClose(t *testing.T) {
f := &FanPool{
capacity: 3,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
make(taskChan, 2),
},
}
f.Close()
require.Equal(t, uint64(0), f.Size())
require.Nil(t, f.queue)
}

15
go.mod
View File

@@ -1,17 +1,15 @@
module github.com/mochi-co/mqtt
module github.com/mochi-mqtt/server/v2
go 1.19
go 1.21
require (
github.com/alicebob/miniredis/v2 v2.23.0
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/cespare/xxhash/v2 v2.1.2
github.com/go-redis/redis/v8 v8.11.5
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
github.com/rs/xid v1.4.0
github.com/rs/zerolog v1.28.0
github.com/stretchr/testify v1.7.1
github.com/timshannon/badgerhold v1.0.0
go.etcd.io/bbolt v1.3.5
@@ -21,6 +19,7 @@ require (
require (
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/badger v1.6.0 // indirect
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
@@ -28,13 +27,11 @@ require (
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/golang/protobuf v1.5.0 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/google/go-cmp v0.5.8 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
golang.org/x/net v0.0.0-20220927171203-f486391704dc // indirect
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sys v0.5.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
)

32
go.sum
View File

@@ -23,7 +23,6 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -38,9 +37,9 @@ github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
@@ -48,8 +47,9 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
@@ -62,15 +62,14 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@@ -79,8 +78,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
@@ -109,24 +106,21 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20220927171203-f486391704dc h1:FxpXZdoBqT8RjqTy6i1E8nXHhW21wK7ptQ/EPIGxzPQ=
golang.org/x/net v0.0.0-20220927171203-f486391704dc/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
@@ -136,8 +130,10 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

261
hooks.go
View File

@@ -1,19 +1,19 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co, thedevop, dgduncan
package mqtt
import (
"errors"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
"github.com/rs/zerolog"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
)
const (
@@ -24,6 +24,7 @@ const (
OnConnectAuthenticate
OnACLCheck
OnConnect
OnSessionEstablish
OnSessionEstablished
OnDisconnect
OnAuthPacket
@@ -38,15 +39,17 @@ const (
OnUnsubscribed
OnPublish
OnPublished
OnPublishDropped
OnRetainMessage
OnRetainPublished
OnQosPublish
OnQosComplete
OnQosDropped
OnPacketIDExhausted
OnWill
OnWillSent
OnClientExpired
OnRetainedExpired
OnExpireInflights
StoredClients
StoredSubscriptions
StoredInflightMessages
@@ -66,13 +69,14 @@ type Hook interface {
Provides(b byte) bool
Init(config any) error
Stop() error
SetOpts(l *zerolog.Logger, o *HookOptions)
SetOpts(l *slog.Logger, o *HookOptions)
OnStarted()
OnStopped()
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
OnACLCheck(cl *Client, topic string, write bool) bool
OnSysInfoTick(*system.Info)
OnConnect(cl *Client, pk packets.Packet)
OnConnect(cl *Client, pk packets.Packet) error
OnSessionEstablish(cl *Client, pk packets.Packet)
OnSessionEstablished(cl *Client, pk packets.Packet)
OnDisconnect(cl *Client, err error, expire bool)
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
@@ -87,15 +91,17 @@ type Hook interface {
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnPublishDropped(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnRetainPublished(cl *Client, pk packets.Packet)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
OnQosDropped(cl *Client, pk packets.Packet)
OnPacketIDExhausted(cl *Client, pk packets.Packet)
OnWill(cl *Client, will Will) (Will, error)
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
OnExpireInflights(cl *Client, expiry int64)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
@@ -110,11 +116,11 @@ type HookOptions struct {
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal []Hook // a slice of hooks
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex
Log *slog.Logger // a logger for the hook (from the server)
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
@@ -124,7 +130,7 @@ func (h *Hooks) Len() int64 {
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
for _, hb := range b {
if hook.Provides(hb) {
return true
@@ -139,29 +145,42 @@ func (h *Hooks) Provides(b ...byte) bool {
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
if h.internal == nil {
h.internal = []Hook{}
}
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
h.internal = append(h.internal, hook)
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.internal {
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
for _, hook := range h.GetAll() {
h.Log.Info("stopping hook", "hook", hook.ID())
if err := hook.Stop(); err != nil {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
h.Log.Debug("problem stopping hook", "error", err, "hook", hook.ID())
}
h.wg.Done()
@@ -173,7 +192,7 @@ func (h *Hooks) Stop() {
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
@@ -182,7 +201,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) {
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
@@ -191,25 +210,39 @@ func (h *Hooks) OnStarted() {
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
}
}
// OnConnect is called when a new client connects.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
// OnConnect is called when a new client connects, and may return a packets.Code as an error to halt the connection.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) error {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnect) {
hook.OnConnect(cl, pk)
err := hook.OnConnect(cl, pk)
if err != nil {
return err
}
}
}
return nil
}
// OnSessionEstablish is called right after a new client connects and authenticates and right before
// the session is established and CONNACK is sent.
func (h *Hooks) OnSessionEstablish(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablish) {
hook.OnSessionEstablish(cl, pk)
}
}
}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
@@ -218,7 +251,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
@@ -228,11 +261,11 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected")
h.Log.Debug("packet rejected", "hook", hook.ID(), "packet", pkx)
return pk, err
} else if err != nil {
continue
@@ -249,7 +282,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
@@ -265,7 +298,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet,
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
@@ -276,7 +309,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
@@ -286,7 +319,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
@@ -298,7 +331,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
@@ -308,7 +341,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
@@ -320,7 +353,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
@@ -333,7 +366,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
@@ -343,28 +376,35 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
}
}
// OnPublish is called when a client publishes a message. This method differs from OnMessage
// OnPublish is called when a client publishes a message. This method differs from OnPublished
// in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
if err != nil {
if errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug("publish packet rejected",
"error", err,
"hook", hook.ID(),
"packet", pkx)
return pk, err
}
h.Log.Error("publish packet error",
"error", err,
"hook", hook.ID(),
"packet", pkx)
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
@@ -374,27 +414,46 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
}
}
// OnPublishDropped is called when a message to a client was dropped instead of delivered
// such as when a client is too slow to respond.
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublishDropped) {
hook.OnPublishDropped(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
}
}
// OnRetainPublished is called when a retained message is published.
func (h *Hooks) OnRetainPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainPublished) {
hook.OnRetainPublished(cl, pk)
}
}
}
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
@@ -405,7 +464,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
@@ -413,26 +472,39 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned.
// It is typically used to delete an inflight message from a store.
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
}
}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to
// assign to a packet.
func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketIDExhausted) {
hook.OnPacketIDExhausted(cl, pk)
}
}
}
// OnWill is called when a client disconnects and publishes an LWT message. This method
// differs from OnWillSent in that it allows you to modify the LWT message before it is
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error")
h.Log.Error("parse will error",
"error", err,
"hook", hook.ID(),
"will", will)
continue
}
will = mlwt
@@ -444,7 +516,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
@@ -453,7 +525,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
@@ -462,7 +534,7 @@ func (h *Hooks) OnClientExpired(cl *Client) {
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
@@ -472,11 +544,11 @@ func (h *Hooks) OnRetainedExpired(filter string) {
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients")
h.Log.Error("failed to load clients", "error", err, "hook", hook.ID())
return v, err
}
@@ -492,11 +564,11 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions")
h.Log.Error("failed to load subscriptions", "error", err, "hook", hook.ID())
return v, err
}
@@ -512,11 +584,11 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages")
h.Log.Error("failed to load inflight messages", "error", err, "hook", hook.ID())
return v, err
}
@@ -532,11 +604,11 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages")
h.Log.Error("failed to load retained messages", "error", err, "hook", hook.ID())
return v, err
}
@@ -551,11 +623,11 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info")
h.Log.Error("failed to load $SYS info", "error", err, "hook", hook.ID())
return v, err
}
@@ -573,7 +645,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
@@ -589,7 +661,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
@@ -600,24 +672,11 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnExpireInflights is called when the server issues a clear request for expired
// inflight messages. Expiry should be the time after which the message is no longer
// valid (usually some time in the past). A message has expired if it's created time
// is older than time.Now() minus Inflight TTL. This method can be used to expire
// old inflight messages in a persistent store which doesnt support per-item TTL.
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
for _, hook := range h.internal {
if hook.Provides(OnExpireInflights) {
hook.OnExpireInflights(cl, expiry)
}
}
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
Hook
Log *zerolog.Logger
Log *slog.Logger
Opts *HookOptions
}
@@ -640,12 +699,12 @@ func (h *HookBase) Init(config any) error {
// SetOpts is called by the server to propagate internal values and generally should
// not be called manually.
func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
func (h *HookBase) SetOpts(l *slog.Logger, opts *HookOptions) {
h.Log = l
h.Opts = opts
}
// Stop is called to gracefully shutdown the hook.
// Stop is called to gracefully shut down the hook.
func (h *HookBase) Stop() error {
return nil
}
@@ -670,7 +729,13 @@ func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
}
// OnConnect is called when a new client connects.
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {}
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) error {
return nil
}
// OnSessionEstablish is called right after a new client connects and authenticates and right before
// the session is established and CONNACK is sent.
func (h *HookBase) OnSessionEstablish(cl *Client, pk packets.Packet) {}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
@@ -728,9 +793,15 @@ func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, err
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
// OnRetainPublished is called when a retained message is published.
func (h *HookBase) OnRetainPublished(cl *Client, pk packets.Packet) {}
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
@@ -740,6 +811,9 @@ func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
// OnQosDropped is called the Qos flow for a message expires.
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
// OnWill is called when a client disconnects and publishes an LWT message.
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
return will, nil
@@ -754,9 +828,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return

View File

@@ -1,13 +1,14 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
)
// AllowHook is an authentication hook which allows connection access

View File

@@ -1,13 +1,14 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/stretchr/testify/require"
)

View File

@@ -1,13 +1,14 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/packets"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
)
// Options contains the configuration/rules data for the auth ledger.
@@ -66,10 +67,9 @@ func (h *Hook) Init(config any) error {
}
}
h.Log.Info().
Int("authentication", len(h.ledger.Auth)).
Int("acl", len(h.ledger.ACL)).
Msg("loaded auth rules")
h.Log.Info("loaded auth rules",
"authentication", len(h.ledger.Auth),
"acl", len(h.ledger.ACL))
return nil
}
@@ -81,11 +81,9 @@ func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
return true
}
h.Log.Info().
Str("username", string(pk.Connect.Username)).
Str("remote", cl.Net.Remote).
Msg("client failed authentication check")
h.Log.Info("client failed authentication check",
"username", string(pk.Connect.Username),
"remote", cl.Net.Remote)
return false
}
@@ -96,11 +94,10 @@ func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return true
}
h.Log.Debug().
Str("client", cl.ID).
Str("username", string(cl.Properties.Username)).
Str("topic", topic).
Msg("client failed allowed ACL check")
h.Log.Debug("client failed allowed ACL check",
"client", cl.ID,
"username", string(cl.Properties.Username),
"topic", topic)
return false
}

View File

@@ -1,19 +1,20 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"log/slog"
"os"
"testing"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/packets"
"github.com/rs/zerolog"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/stretchr/testify/require"
)
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
// func teardown(t *testing.T, path string, h *Hook) {
// h.Stop()
@@ -33,7 +34,7 @@ func TestBasicProvides(t *testing.T) {
func TestBasicInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
@@ -41,7 +42,7 @@ func TestBasicInitBadConfig(t *testing.T) {
func TestBasicInitDefaultConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
@@ -49,7 +50,7 @@ func TestBasicInitDefaultConfig(t *testing.T) {
func TestBasicInitWithLedgerPointer(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
ln := &Ledger{
Auth: []AuthRule{
@@ -78,7 +79,7 @@ func TestBasicInitWithLedgerPointer(t *testing.T) {
func TestBasicInitWithLedgerJSON(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
@@ -92,7 +93,7 @@ func TestBasicInitWithLedgerJSON(t *testing.T) {
func TestBasicInitWithLedgerYAML(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
@@ -106,7 +107,7 @@ func TestBasicInitWithLedgerYAML(t *testing.T) {
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
@@ -118,7 +119,7 @@ func TestBasicInitWithLedgerBadDAta(t *testing.T) {
func TestOnConnectAuthenticate(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
@@ -157,7 +158,7 @@ func TestOnConnectAuthenticate(t *testing.T) {
func TestOnACL(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
@@ -8,9 +9,10 @@ import (
"strings"
"sync"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/packets"
"gopkg.in/yaml.v3"
"github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
)
const (
@@ -78,8 +80,8 @@ func (r RString) Matches(a string) bool {
}
// FilterMatches returns true if a filter matches a topic rule.
func (f RString) FilterMatches(a string) bool {
_, ok := MatchTopic(string(f), a)
func (r RString) FilterMatches(a string) bool {
_, ok := MatchTopic(string(r), a)
return ok
}
@@ -159,7 +161,7 @@ func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
}
// ACLOk returns true if the rules indicate the user is allowed to read or write to
// a specific filter or topic respectively, based on the write bool.
// a specific filter or topic respectively, based on the `write` bool.
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
@@ -187,17 +189,31 @@ func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok boo
return n, true
}
for filter, access := range rule.Filters {
if filter.FilterMatches(topic) {
if !write && (access == ReadOnly || access == ReadWrite) {
return n, true
} else if write && (access == WriteOnly || access == ReadWrite) {
return n, true
} else {
return n, false
if write {
for filter, access := range rule.Filters {
if access == WriteOnly || access == ReadWrite {
if filter.FilterMatches(topic) {
return n, true
}
}
}
}
if !write {
for filter, access := range rule.Filters {
if access == ReadOnly || access == ReadWrite {
if filter.FilterMatches(topic) {
return n, true
}
}
}
}
for filter := range rule.Filters {
if filter.FilterMatches(topic) {
return n, false
}
}
}
}

View File

@@ -1,13 +1,14 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/stretchr/testify/require"
)
@@ -560,17 +561,17 @@ func TestLedgerUpdate(t *testing.T) {
},
}
new := &Ledger{
n := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
{Remote: "192.168.*", Allow: true},
},
}
old.Update(new)
old.Update(n)
require.Len(t, old.Auth, 2)
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
require.NotSame(t, new, old)
require.NotSame(t, n, old)
}
func TestLedgerToJSON(t *testing.T) {

View File

@@ -1,16 +1,17 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package debug
import (
"fmt"
"log/slog"
"strings"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/rs/zerolog"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
)
// Options contains configuration settings for the debug output.
@@ -24,7 +25,7 @@ type Options struct {
type Hook struct {
mqtt.HookBase
config *Options
Log *zerolog.Logger
Log *slog.Logger
}
// ID returns the ID of the hook.
@@ -53,25 +54,25 @@ func (h *Hook) Init(config any) error {
}
// SetOpts is called when the hook receives inheritable server parameters.
func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) {
func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) {
h.Log = l
h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send()
h.Log.Debug("", "method", "SetOpts")
}
// Stop is called when the hook is stopped.
func (h *Hook) Stop() error {
h.Log.Debug().Str("method", "Stop").Send()
h.Log.Debug("", "method", "Stop")
return nil
}
// OnStarted is called when the server starts.
func (h *Hook) OnStarted() {
h.Log.Debug().Str("method", "OnStarted").Send()
h.Log.Debug("", "method", "OnStarted")
}
// OnStopped is called when the server stops.
func (h *Hook) OnStopped() {
h.Log.Debug().Str("method", "OnStopped").Send()
h.Log.Debug("", "method", "OnStopped")
}
// OnPacketRead is called when a new packet is received from a client.
@@ -80,8 +81,7 @@ func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet,
return pk, nil
}
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
return pk, nil
}
@@ -91,85 +91,72 @@ func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
return
}
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk))
}
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic")
h.Log.Debug("retained message on topic", "m", h.packetMeta(pk))
}
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out")
h.Log.Debug("inflight out", "m", h.packetMeta(pk))
}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete")
h.Log.Debug("inflight complete", "m", h.packetMeta(pk))
}
// OnQosDropped is called the Qos flow for a message expires.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped")
h.Log.Debug("inflight dropped", "m", h.packetMeta(pk))
}
// OnLWTSent is called when a will message has been issued from a disconnecting client.
// OnLWTSent is called when a Will Message has been issued from a disconnecting client.
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client")
h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID)
}
// OnRetainedExpired is called when the server clears expired retained messages.
func (h *Hook) OnRetainedExpired(filter string) {
h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired")
h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter)
}
// OnClientExpired is called when the server clears an expired client.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired")
h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID)
}
// StoredClients is called when the server restores clients from a store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
h.Log.Debug().
Str("method", "StoredClients").
Send()
h.Log.Debug("", "method", "StoredClients")
return v, nil
}
// StoredClients is called when the server restores subscriptions from a store.
// StoredSubscriptions is called when the server restores subscriptions from a store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
h.Log.Debug().
Str("method", "StoredSubscriptions").
Send()
h.Log.Debug("", "method", "StoredSubscriptions")
return v, nil
}
// StoredClients is called when the server restores retained messages from a store.
// StoredRetainedMessages is called when the server restores retained messages from a store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
h.Log.Debug().
Str("method", "StoredRetainedMessages").
Send()
h.Log.Debug("", "method", "StoredRetainedMessages")
return v, nil
}
// StoredClients is called when the server restores inflight messages from a store.
// StoredInflightMessages is called when the server restores inflight messages from a store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
h.Log.Debug().
Str("method", "StoredInflightMessages").
Send()
h.Log.Debug("", "method", "StoredInflightMessages")
return v, nil
}
// StoredClients is called when the server restores system info from a store.
// StoredSysInfo is called when the server restores system info from a store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
h.Log.Debug().
Str("method", "StoredClients").
Send()
h.Log.Debug("", "method", "StoredSysInfo")
return v, nil
}

View File

@@ -1,17 +1,19 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co, gsagula
package badger
import (
"bytes"
"errors"
"fmt"
"strings"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
"github.com/timshannon/badgerhold"
)
@@ -79,7 +81,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -127,8 +128,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
@@ -136,7 +136,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -165,14 +165,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data")
h.Log.Error("failed to upsert client data", "error", err, "data", in)
}
}
// OnDisconnect removes a client from the store if their session has expired.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -182,32 +182,40 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
h.Log.Error("failed to delete client data", "error", err, "data", clientKey(cl))
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data")
h.Log.Error("failed to upsert subscription data", "error", err, "data", in)
}
}
}
@@ -215,14 +223,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription))
if err != nil {
h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data")
h.Log.Error("failed to delete subscription data", "error", err, "data", subscriptionKey(cl, pk.Filters[i].Filter))
}
}
}
@@ -230,14 +238,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data")
h.Log.Error("failed to delete retained message data", "error", err, "data", retainedKey(pk.TopicName))
}
return
@@ -266,14 +274,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data")
h.Log.Error("failed to upsert retained message data", "error", err, "data", in)
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -302,27 +310,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data")
h.Log.Error("failed to upsert qos inflight data", "error", err, "data", in)
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(inflightKey(cl, pk), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data")
h.Log.Error("failed to delete inflight message data", "error", err, "data", inflightKey(cl, pk))
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
@@ -331,66 +339,52 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
Info: *sys.Clone(),
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data")
}
}
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.Delete(m.ID, new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
}
}
h.Log.Error("failed to upsert $SYS data", "error", err, "data", in)
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(retainedKey(filter), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
h.Log.Error("failed to delete expired retained message data", "error", err, "id", retainedKey(filter))
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
h.Log.Error("failed to delete expired client data", "error", err, "id", clientKey(cl))
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -405,7 +399,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -420,7 +414,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -435,7 +429,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -450,7 +444,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -464,20 +458,21 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
// Errorf satisfies the badger interface for an error logger.
func (h *Hook) Errorf(m string, v ...interface{}) {
h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
h.Log.Error(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Warningf satisfies the badger interface for a warning logger.
func (h *Hook) Warningf(m string, v ...interface{}) {
h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
h.Log.Warn(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Infof satisfies the badger interface for an info logger.
func (h *Hook) Infof(m string, v ...interface{}) {
h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
h.Log.Info(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}
// Debugf satisfies the badger interface for a debug logger.
func (h *Hook) Debugf(m string, v ...interface{}) {
h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
h.Log.Debug(fmt.Sprintf(strings.ToLower(strings.Trim(m, "\n")), v...), "v", v)
}

View File

@@ -1,27 +1,26 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"errors"
"log/slog"
"os"
"strings"
"testing"
"time"
"github.com/asdine/storm/v3"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
"github.com/rs/zerolog"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
"github.com/stretchr/testify/require"
"github.com/timshannon/badgerhold"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
client = &mqtt.Client{
ID: "test",
@@ -39,8 +38,8 @@ var (
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
h.db.Badger().Close()
_ = h.Stop()
_ = h.db.Badger().Close()
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
require.NoError(t, err)
}
@@ -96,7 +95,7 @@ func TestProvides(t *testing.T) {
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
@@ -104,7 +103,7 @@ func TestInitBadConfig(t *testing.T) {
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -114,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) {
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -147,7 +146,7 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -169,15 +168,30 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, badgerhold.ErrNotFound, err)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -186,7 +200,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) {
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -205,22 +219,45 @@ func TestOnWillSent(t *testing.T) {
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -242,13 +279,13 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -257,13 +294,13 @@ func TestOnSubscribedClosedDB(t *testing.T) {
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -272,7 +309,7 @@ func TestOnUnsubscribedClosedDB(t *testing.T) {
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -307,7 +344,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) {
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -332,15 +369,30 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -349,7 +401,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) {
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -384,13 +436,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) {
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -399,13 +451,13 @@ func TestOnQosPublishClosedDB(t *testing.T) {
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -414,55 +466,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) {
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -484,13 +494,13 @@ func TestOnSysInfoTick(t *testing.T) {
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -499,7 +509,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) {
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -524,7 +534,7 @@ func TestStoredClients(t *testing.T) {
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
@@ -532,7 +542,7 @@ func TestStoredClientsNoDB(t *testing.T) {
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -557,7 +567,7 @@ func TestStoredSubscriptions(t *testing.T) {
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
@@ -565,7 +575,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) {
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -593,7 +603,7 @@ func TestStoredRetainedMessages(t *testing.T) {
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
@@ -601,7 +611,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) {
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -629,7 +639,7 @@ func TestStoredInflightMessages(t *testing.T) {
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
@@ -637,7 +647,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) {
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -659,7 +669,7 @@ func TestStoredSysInfo(t *testing.T) {
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
@@ -668,27 +678,27 @@ func TestStoredSysInfoNoDB(t *testing.T) {
func TestErrorf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.Errorf("test", 1, 2, 3)
}
func TestWarningf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.Warningf("test", 1, 2, 3)
}
func TestInfof(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.Infof("test", 1, 2, 3)
}
func TestDebugf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.Debugf("test", 1, 2, 3)
}

View File

@@ -1,7 +1,8 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
// Package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
package bolt
import (
@@ -9,10 +10,10 @@ import (
"errors"
"time"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
sgob "github.com/asdine/storm/codec/gob"
"github.com/asdine/storm/v3"
@@ -85,7 +86,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -133,8 +133,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
@@ -142,7 +141,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -170,14 +169,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data")
h.Log.Error("failed to save client data", "error", err, "data", in)
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -185,34 +184,40 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save subscription data")
h.Log.Error("failed to save subscription data", "error", err, "client", cl.ID, "data", in)
}
}
}
@@ -220,7 +225,7 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -229,9 +234,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
ID: subscriptionKey(cl, pk.Filters[i].Filter),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", subscriptionKey(cl, pk.Filters[i].Filter)).
Msg("failed to delete client")
h.Log.Error("failed to delete client", "error", err, "id", subscriptionKey(cl, pk.Filters[i].Filter))
}
}
}
@@ -239,7 +242,7 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -248,9 +251,7 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
ID: retainedKey(pk.TopicName),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", retainedKey(pk.TopicName)).
Msg("failed to delete retained publish")
h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(pk.TopicName))
}
return
}
@@ -277,17 +278,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save retained publish data")
h.Log.Error("failed to save retained publish data", "error", err, "client", cl.ID, "data", in)
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -315,17 +313,14 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save qos inflight data")
h.Log.Error("failed to save qos inflight data", "error", err, "client", cl.ID, "data", in)
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -333,16 +328,14 @@ func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
ID: inflightKey(cl, pk),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", inflightKey(cl, pk)).
Msg("failed to delete inflight data")
h.Log.Error("failed to delete inflight data", "error", err, "id", inflightKey(cl, pk))
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
@@ -351,7 +344,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -363,57 +356,39 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Interface("data", in).
Msg("failed to save $SYS data")
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
return
}
}
h.Log.Error("failed to save $SYS data", "error", err, "data", in)
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
h.Log.Error("failed to delete retained publish", "error", err, "id", retainedKey(filter))
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -428,7 +403,7 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -443,7 +418,7 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -458,7 +433,7 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -473,7 +448,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}

View File

@@ -1,26 +1,26 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package bolt
import (
"errors"
"log/slog"
"os"
"testing"
"time"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
"github.com/asdine/storm/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
client = &mqtt.Client{
ID: "test",
@@ -38,7 +38,7 @@ var (
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
_ = h.Stop()
err := os.Remove(path)
require.NoError(t, err)
}
@@ -94,7 +94,7 @@ func TestProvides(t *testing.T) {
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
@@ -102,7 +102,7 @@ func TestInitBadConfig(t *testing.T) {
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -113,7 +113,7 @@ func TestInitUseDefaults(t *testing.T) {
func TestInitBadPath(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Path: "..",
})
@@ -122,7 +122,7 @@ func TestInitBadPath(t *testing.T) {
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -155,13 +155,13 @@ func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -170,7 +170,7 @@ func TestOnSessionEstablishedClosedDB(t *testing.T) {
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -189,7 +189,7 @@ func TestOnWillSent(t *testing.T) {
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -211,24 +211,62 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, storm.ErrNotFound, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -250,13 +288,13 @@ func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -265,13 +303,13 @@ func TestOnSubscribedClosedDB(t *testing.T) {
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -280,7 +318,7 @@ func TestOnUnsubscribedClosedDB(t *testing.T) {
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -315,7 +353,7 @@ func TestOnRetainMessageThenUnset(t *testing.T) {
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -340,15 +378,30 @@ func TestOnRetainedExpired(t *testing.T) {
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -357,7 +410,7 @@ func TestOnRetainMessageClosedDB(t *testing.T) {
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -392,13 +445,13 @@ func TestOnQosPublishThenQOSComplete(t *testing.T) {
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -407,13 +460,13 @@ func TestOnQosPublishClosedDB(t *testing.T) {
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -422,55 +475,13 @@ func TestOnQosCompleteClosedDB(t *testing.T) {
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -492,13 +503,13 @@ func TestOnSysInfoTick(t *testing.T) {
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -507,7 +518,7 @@ func TestOnSysInfoTickClosedDB(t *testing.T) {
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -532,7 +543,7 @@ func TestStoredClients(t *testing.T) {
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
@@ -540,7 +551,7 @@ func TestStoredClientsNoDB(t *testing.T) {
func TestStoredClientsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -551,7 +562,7 @@ func TestStoredClientsClosedDB(t *testing.T) {
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -576,7 +587,7 @@ func TestStoredSubscriptions(t *testing.T) {
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
@@ -584,7 +595,7 @@ func TestStoredSubscriptionsNoDB(t *testing.T) {
func TestStoredSubscriptionsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -595,7 +606,7 @@ func TestStoredSubscriptionsClosedDB(t *testing.T) {
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -623,7 +634,7 @@ func TestStoredRetainedMessages(t *testing.T) {
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
@@ -631,7 +642,7 @@ func TestStoredRetainedMessagesNoDB(t *testing.T) {
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -642,7 +653,7 @@ func TestStoredRetainedMessagesClosedDB(t *testing.T) {
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -670,7 +681,7 @@ func TestStoredInflightMessages(t *testing.T) {
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
@@ -678,7 +689,7 @@ func TestStoredInflightMessagesNoDB(t *testing.T) {
func TestStoredInflightMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
@@ -689,7 +700,7 @@ func TestStoredInflightMessagesClosedDB(t *testing.T) {
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
@@ -711,7 +722,7 @@ func TestStoredSysInfo(t *testing.T) {
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
@@ -719,7 +730,7 @@ func TestStoredSysInfoNoDB(t *testing.T) {
func TestStoredSysInfoClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
@@ -9,12 +10,12 @@ import (
"errors"
"fmt"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
redis "github.com/go-redis/redis/v8"
"github.com/go-redis/redis/v8"
)
// defaultAddr is the default address to the redis service.
@@ -82,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -117,12 +117,11 @@ func (h *Hook) Init(config any) error {
h.config.HPrefix = defaultHPrefix
}
h.Log.Info().
Str("address", h.config.Options.Addr).
Str("username", h.config.Options.Username).
Int("password-len", len(h.config.Options.Password)).
Int("db", h.config.Options.DB).
Msg("connecting to redis service")
h.Log.Info("connecting to redis service",
"address", h.config.Options.Addr,
"username", h.config.Options.Username,
"password-len", len(h.config.Options.Password),
"db", h.config.Options.DB)
h.db = redis.NewClient(h.config.Options)
_, err := h.db.Ping(context.Background()).Result()
@@ -130,14 +129,15 @@ func (h *Hook) Init(config any) error {
return fmt.Errorf("failed to ping service: %w", err)
}
h.Log.Info().Msg("connected to redis service")
h.Log.Info("connected to redis service")
return nil
}
// Close closes the redis connection.
// Stop closes the redis connection.
func (h *Hook) Stop() error {
h.Log.Info().Msg("disconnecting from redis service")
h.Log.Info("disconnecting from redis service")
return h.db.Close()
}
@@ -146,8 +146,7 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
// OnWillSent is called when a client sends a Will Message and the Will Message is removed from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
@@ -155,7 +154,7 @@ func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -184,14 +183,14 @@ func (h *Hook) updateClient(cl *mqtt.Client) {
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data")
h.Log.Error("failed to hset client data", "error", err, "data", in)
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -199,32 +198,40 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
h.Log.Error("failed to delete client", "error", err, "id", clientKey(cl))
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data")
h.Log.Error("failed to hset subscription data", "error", err, "data", in)
}
}
}
@@ -232,14 +239,14 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data")
h.Log.Error("failed to delete subscription data", "error", err, "id", clientKey(cl))
}
}
}
@@ -247,14 +254,14 @@ func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data")
h.Log.Error("failed to delete retained message data", "error", err, "id", retainedKey(pk.TopicName))
}
return
@@ -283,14 +290,14 @@ func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data")
h.Log.Error("failed to hset retained message data", "error", err, "data", in)
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -318,27 +325,27 @@ func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, rese
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data")
h.Log.Error("failed to hset qos inflight message data", "error", err, "data", in)
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
h.Log.Error("failed to delete qos inflight message data", "error", err, "id", inflightKey(cl, pk))
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
@@ -347,7 +354,7 @@ func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -359,72 +366,53 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data")
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
if d.Created < expiry || d.Created == 0 {
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
h.Log.Error("failed to hset server info data", "error", err, "data", in)
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
h.Log.Error("failed to delete expired retained message", "error", err, "id", retainedKey(filter))
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
h.Log.Error("failed to delete expired client", "error", err, "id", clientKey(cl))
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll client data")
h.Log.Error("failed to HGetAll client data", "error", err)
return
}
for _, row := range rows {
var d storage.Client
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data")
h.Log.Error("failed to unmarshal client data", "error", err, "data", row)
}
v = append(v, d)
@@ -436,20 +424,20 @@ func (h *Hook) StoredClients() (v []storage.Client, err error) {
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll subscription data")
h.Log.Error("failed to HGetAll subscription data", "error", err)
return
}
for _, row := range rows {
var d storage.Subscription
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data")
h.Log.Error("failed to unmarshal subscription data", "error", err, "data", row)
}
v = append(v, d)
@@ -461,20 +449,20 @@ func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll retained message data")
h.Log.Error("failed to HGetAll retained message data", "error", err)
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data")
h.Log.Error("failed to unmarshal retained message data", "error", err, "data", row)
}
v = append(v, d)
@@ -486,20 +474,20 @@ func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data")
h.Log.Error("failed to HGetAll inflight message data", "error", err)
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
h.Log.Error("failed to unmarshal inflight message data", "error", err, "data", row)
}
v = append(v, d)
@@ -511,7 +499,7 @@ func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
h.Log.Error("", "error", storage.ErrDBFileNotOpen)
return
}
@@ -521,7 +509,7 @@ func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
}
if err = v.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data")
h.Log.Error("failed to unmarshal sys info data", "error", err, "data", row)
}
return v, nil

View File

@@ -1,27 +1,28 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"log/slog"
"os"
"sort"
"testing"
"time"
"github.com/mochi-co/mqtt"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
miniredis "github.com/alicebob/miniredis/v2"
redis "github.com/go-redis/redis/v8"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
client = &mqtt.Client{
ID: "test",
@@ -40,7 +41,7 @@ var (
func newHook(t *testing.T, addr string) *Hook {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
@@ -86,13 +87,13 @@ func TestSysInfoKey(t *testing.T) {
func TestID(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
require.Equal(t, "redis-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
@@ -115,7 +116,7 @@ func TestHKey(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
}
@@ -125,7 +126,7 @@ func TestInitUseDefaults(t *testing.T) {
defer s.Close()
h := newHook(t, defaultAddr)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h)
@@ -136,7 +137,7 @@ func TestInitUseDefaults(t *testing.T) {
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
@@ -144,7 +145,7 @@ func TestInitBadConfig(t *testing.T) {
func TestInitBadAddr(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.SetOpts(logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: "abc:123",
@@ -252,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, redis.Nil, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -268,6 +285,28 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -391,6 +430,22 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -483,60 +538,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
n := time.Now().Unix()
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
&storage.Message{ID: "i3", T: storage.InflightKey},
).Err()
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var r []storage.Message
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
require.NoError(t, err)
require.Len(t, rows, 1)
for _, row := range rows {
var d storage.Message
err = d.UnmarshalBinary([]byte(row))
require.NoError(t, err)
r = append(r, d)
}
require.Len(t, r, 1)
require.Equal(t, "i1", r[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()

View File

@@ -1,14 +1,15 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"encoding/json"
"errors"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
)
const (
@@ -24,7 +25,7 @@ var (
ErrDBFileNotOpen = errors.New("db file not open")
)
// Client is a storable representation of an mqtt client.
// Client is a storable representation of an MQTT client.
type Client struct {
Will ClientWill `json:"will"` // will topic and payload data if applicable
Properties ClientProperties `json:"properties"` // the connect properties for the client
@@ -116,7 +117,37 @@ func (d *Message) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, d)
}
// Subscription is a storable representation of an mqtt subscription.
// ToPacket converts a storage.Message to a standard packet.
func (d *Message) ToPacket() packets.Packet {
pk := packets.Packet{
FixedHeader: d.FixedHeader,
PacketID: d.PacketID,
TopicName: d.TopicName,
Payload: d.Payload,
Origin: d.Origin,
Created: d.Created,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
}
// Return a deep copy of the packet data otherwise the slices will
// continue pointing at the values from the storage packet.
pk = pk.Copy(true)
pk.FixedHeader.Dup = d.FixedHeader.Dup
return pk
}
// Subscription is a storable representation of an MQTT subscription.
type Subscription struct {
T string `json:"t"`
ID string `json:"id" storm:"id"`

View File

@@ -1,14 +1,15 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"testing"
"time"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
"github.com/stretchr/testify/require"
)
@@ -103,6 +104,7 @@ var (
ClientsMaximum: 7,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
PacketsReceived: 12,
PacketsSent: 13,
Retained: 15,
@@ -110,13 +112,13 @@ var (
InflightDropped: 17,
},
}
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)
func TestClientMarshalBinary(t *testing.T) {
data, err := clientStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, clientJSON, data)
require.JSONEq(t, string(clientJSON), string(data))
}
func TestClientUnmarshalBinary(t *testing.T) {
@@ -136,7 +138,7 @@ func TestClientUnmarshalBinaryEmpty(t *testing.T) {
func TestMessageMarshalBinary(t *testing.T) {
data, err := messageStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, messageJSON, data)
require.JSONEq(t, string(messageJSON), string(data))
}
func TestMessageUnmarshalBinary(t *testing.T) {
@@ -156,7 +158,7 @@ func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
func TestSubscriptionMarshalBinary(t *testing.T) {
data, err := subscriptionStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, subscriptionJSON, data)
require.JSONEq(t, string(subscriptionJSON), string(data))
}
func TestSubscriptionUnmarshalBinary(t *testing.T) {
@@ -176,7 +178,7 @@ func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
func TestSysInfoMarshalBinary(t *testing.T) {
data, err := sysInfoStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, sysInfoJSON, data)
require.JSONEq(t, string(sysInfoJSON), string(data))
}
func TestSysInfoUnmarshalBinary(t *testing.T) {
@@ -192,3 +194,35 @@ func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
require.NoError(t, err)
require.Equal(t, SystemInfo{}, d)
}
func TestMessageToPacket(t *testing.T) {
d := messageStruct
pk := d.ToPacket()
require.Equal(t, packets.Packet{
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: d.FixedHeader.Remaining,
Type: d.FixedHeader.Type,
Qos: d.FixedHeader.Qos,
Dup: d.FixedHeader.Dup,
Retain: d.FixedHeader.Retain,
},
Origin: d.Origin,
TopicName: d.TopicName,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
PacketID: 100,
Created: d.Created,
}, pk)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -10,9 +11,9 @@ import (
"testing"
"time"
"github.com/mochi-co/mqtt/hooks/storage"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/system"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
"github.com/stretchr/testify/require"
)
@@ -26,6 +27,10 @@ type modifiedHookBase struct {
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
@@ -45,6 +50,14 @@ func (h *modifiedHookBase) Stop() error {
return nil
}
func (h *modifiedHookBase) OnConnect(cl *Client, pk packets.Packet) error {
if h.fail {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return true
}
@@ -177,12 +190,20 @@ func TestHooksProvides(t *testing.T) {
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddAndLen(t *testing.T) {
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {
@@ -194,7 +215,7 @@ func TestHooksAddInitFailure(t *testing.T) {
func TestHooksStop(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
@@ -215,7 +236,7 @@ func TestHooksNonReturns(t *testing.T) {
h.OnStarted()
h.OnStopped()
h.OnSysInfoTick(new(system.Info))
h.OnConnect(cl, packets.Packet{})
h.OnSessionEstablish(cl, packets.Packet{})
h.OnSessionEstablished(cl, packets.Packet{})
h.OnDisconnect(cl, nil, false)
h.OnPacketSent(cl, packets.Packet{}, []byte{})
@@ -223,14 +244,16 @@ func TestHooksNonReturns(t *testing.T) {
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnPublishDropped(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnRetainPublished(cl, packets.Packet{})
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})
h.OnQosDropped(cl, packets.Packet{})
h.OnPacketIDExhausted(cl, packets.Packet{})
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
h.OnExpireInflights(cl, time.Now().Unix()-1)
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)
@@ -311,7 +334,7 @@ func TestHooksOnUnsubscribe(t *testing.T) {
func TestHooksOnPublish(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
@@ -324,7 +347,7 @@ func TestHooksOnPublish(t *testing.T) {
// coverage: failure
hook.fail = true
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Error(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
@@ -337,7 +360,7 @@ func TestHooksOnPublish(t *testing.T) {
func TestHooksOnPacketRead(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
@@ -363,7 +386,7 @@ func TestHooksOnPacketRead(t *testing.T) {
func TestHooksOnAuthPacket(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
@@ -379,9 +402,25 @@ func TestHooksOnAuthPacket(t *testing.T) {
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnConnect(t *testing.T) {
h := new(Hooks)
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
hook.fail = true
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
}
func TestHooksOnPacketEncode(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
@@ -393,7 +432,7 @@ func TestHooksOnPacketEncode(t *testing.T) {
func TestHooksOnLWT(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
@@ -410,7 +449,7 @@ func TestHooksOnLWT(t *testing.T) {
func TestHooksStoredClients(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
v, err := h.StoredClients()
require.NoError(t, err)
@@ -432,7 +471,7 @@ func TestHooksStoredClients(t *testing.T) {
func TestHooksStoredSubscriptions(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
v, err := h.StoredSubscriptions()
require.NoError(t, err)
@@ -454,7 +493,7 @@ func TestHooksStoredSubscriptions(t *testing.T) {
func TestHooksStoredRetainedMessages(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
@@ -476,7 +515,7 @@ func TestHooksStoredRetainedMessages(t *testing.T) {
func TestHooksStoredInflightMessages(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
v, err := h.StoredInflightMessages()
require.NoError(t, err)
@@ -498,7 +537,7 @@ func TestHooksStoredInflightMessages(t *testing.T) {
func TestHooksStoredSysInfo(t *testing.T) {
h := new(Hooks)
h.Log = &logger
h.Log = logger
v, err := h.StoredSysInfo()
require.NoError(t, err)
@@ -536,7 +575,7 @@ func TestHookBaseInit(t *testing.T) {
func TestHookBaseSetOpts(t *testing.T) {
h := new(HookBase)
h.SetOpts(&logger, new(HookOptions))
h.SetOpts(logger, new(HookOptions))
require.NotNil(t, h.Log)
require.NotNil(t, h.Opts)
}
@@ -551,12 +590,19 @@ func TestHookBaseOnConnectAuthenticate(t *testing.T) {
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, v)
}
func TestHookBaseOnACLCheck(t *testing.T) {
h := new(HookBase)
v := h.OnACLCheck(new(Client), "topic", true)
require.False(t, v)
}
func TestHookBaseOnConnect(t *testing.T) {
h := new(HookBase)
err := h.OnConnect(new(Client), packets.Packet{})
require.NoError(t, err)
}
func TestHookBaseOnPublish(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -8,7 +9,7 @@ import (
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-mqtt/server/v2/packets"
)
// Inflight is a map of InflightMessage keyed on packet id.
@@ -57,6 +58,18 @@ func (i *Inflight) Len() int {
return len(i.internal)
}
// Clone returns a new instance of Inflight with the same message data.
// This is used when transferring inflights from a taken-over session.
func (i *Inflight) Clone() *Inflight {
c := NewInflights()
i.RLock()
defer i.RUnlock()
for k, v := range i.internal {
c.internal[k] = v
}
return c
}
// GetAll returns all the inflight messages.
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
i.RLock()
@@ -103,14 +116,14 @@ func (i *Inflight) Delete(id uint16) bool {
}
// TakeRecieveQuota reduces the receive quota by 1.
func (i *Inflight) TakeReceiveQuota() {
func (i *Inflight) DecreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) > 0 {
atomic.AddInt32(&i.receiveQuota, -1)
}
}
// TakeRecieveQuota increases the receive quota by 1.
func (i *Inflight) ReturnReceiveQuota() {
func (i *Inflight) IncreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
atomic.AddInt32(&i.receiveQuota, 1)
}
@@ -122,15 +135,15 @@ func (i *Inflight) ResetReceiveQuota(n int32) {
atomic.StoreInt32(&i.maximumReceiveQuota, n)
}
// TakeSendQuota reduces the send quota by 1.
func (i *Inflight) TakeSendQuota() {
// DecreaseSendQuota reduces the send quota by 1.
func (i *Inflight) DecreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) > 0 {
atomic.AddInt32(&i.sendQuota, -1)
}
}
// ReturnSendQuota increases the send quota by 1.
func (i *Inflight) ReturnSendQuota() {
// IncreaseSendQuota increases the send quota by 1.
func (i *Inflight) IncreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
atomic.AddInt32(&i.sendQuota, 1)
}

View File

@@ -1,18 +1,19 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sync/atomic"
"testing"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/stretchr/testify/require"
)
func TestInflightSet(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.True(t, r)
@@ -24,7 +25,7 @@ func TestInflightSet(t *testing.T) {
}
func TestInflightGet(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
msg, ok := cl.State.Inflight.Get(2)
@@ -33,7 +34,7 @@ func TestInflightGet(t *testing.T) {
}
func TestInflightGetAllAndImmediate(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
@@ -55,13 +56,23 @@ func TestInflightGetAllAndImmediate(t *testing.T) {
}
func TestInflightLen(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightClone(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
cloned := cl.State.Inflight.Clone()
require.NotNil(t, cloned)
require.NotSame(t, cloned, cl.State.Inflight)
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
require.NotNil(t, cl.State.Inflight.internal[3])
@@ -94,12 +105,12 @@ func TestReceiveQuota(t *testing.T) {
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
// Return 1
i.ReturnReceiveQuota()
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Try to go over max limit
i.ReturnReceiveQuota()
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
@@ -109,12 +120,12 @@ func TestReceiveQuota(t *testing.T) {
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
// Take 1
i.TakeReceiveQuota()
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
// Try to go below zero
i.TakeReceiveQuota()
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}
@@ -136,12 +147,12 @@ func TestSendQuota(t *testing.T) {
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
// Return 1
i.ReturnSendQuota()
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Try to go over max limit
i.ReturnSendQuota()
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
@@ -151,18 +162,18 @@ func TestSendQuota(t *testing.T) {
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
// Take 1
i.TakeSendQuota()
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
// Try to go below zero
i.TakeSendQuota()
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}
func TestNextImmediate(t *testing.T) {
cl, _, _ := newClient()
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})

View File

@@ -0,0 +1,100 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: Derek Duncan
package listeners
import (
"context"
"log/slog"
"net/http"
"sync"
"sync/atomic"
"time"
)
// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint.
type HTTPHealthCheck struct {
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // the http server
end uint32 // ensure the close methods are only called once
}
// NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address.
func NewHTTPHealthCheck(id, address string, config *Config) *HTTPHealthCheck {
if config == nil {
config = new(Config)
}
return &HTTPHealthCheck{
id: id,
address: address,
config: config,
}
}
// ID returns the id of the listener.
func (l *HTTPHealthCheck) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *HTTPHealthCheck) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *HTTPHealthCheck) Protocol() string {
if l.listen != nil && l.listen.TLSConfig != nil {
return "https"
}
return "http"
}
// Init initializes the listener.
func (l *HTTPHealthCheck) Init(_ *slog.Logger) error {
mux := http.NewServeMux()
mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
}
})
l.listen = &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
Addr: l.address,
Handler: mux,
}
if l.config.TLSConfig != nil {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
// Serve starts listening for new connections and serving responses.
func (l *HTTPHealthCheck) Serve(establish EstablishFn) {
if l.listen.TLSConfig != nil {
_ = l.listen.ListenAndServeTLS("", "")
} else {
_ = l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *HTTPHealthCheck) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = l.listen.Shutdown(ctx)
}
closeClients(l.id)
}

View File

@@ -0,0 +1,143 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: Derek Duncan
package listeners
import (
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewHTTPHealthCheck(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
require.Equal(t, "healthcheck", l.id)
require.Equal(t, testAddr, l.address)
}
func TestHTTPHealthCheckID(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
require.Equal(t, "healthcheck", l.ID())
}
func TestHTTPHealthCheckAddress(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
require.Equal(t, testAddr, l.Address())
}
func TestHTTPHealthCheckProtocol(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
require.Equal(t, "http", l.Protocol())
}
func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
_ = l.Init(logger)
require.Equal(t, "https", l.Protocol())
}
func TestHTTPHealthCheckInit(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
err := l.Init(logger)
require.NoError(t, err)
require.NotNil(t, l.listen)
require.Equal(t, testAddr, l.listen.Addr)
}
func TestHTTPHealthCheckServeAndClose(t *testing.T) {
// setup http stats listener
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
// call healthcheck
resp, err := http.Get("http://localhost" + testAddr + "/healthcheck")
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// ensure listening is closed
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck")
require.Error(t, err)
<-o
}
func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
// setup http stats listener
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
// make disallowed method type http request
resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody)
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// ensure listening is closed
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody)
require.Error(t, err)
<-o
}
func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
l.Close(MockCloser)
}

View File

@@ -1,32 +1,31 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mochi-co/mqtt/system"
"github.com/rs/zerolog"
"github.com/mochi-mqtt/server/v2/system"
)
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // the http server
log *zerolog.Logger // server logger
sysInfo *system.Info // pointers to the server data
end uint32 // ensure the close methods are only called once
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // the http server
sysInfo *system.Info // pointers to the server data
end uint32 // ensure the close methods are only called once
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
@@ -62,9 +61,7 @@ func (l *HTTPStats) Protocol() string {
}
// Init initializes the listener.
func (l *HTTPStats) Init(log *zerolog.Logger) error {
l.log = log
func (l *HTTPStats) Init(_ *slog.Logger) error {
mux := http.NewServeMux()
mux.HandleFunc("/", l.jsonHandler)
l.listen = &http.Server{
@@ -84,9 +81,9 @@ func (l *HTTPStats) Init(log *zerolog.Logger) error {
// Serve starts listening for new connections and serving responses.
func (l *HTTPStats) Serve(establish EstablishFn) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
_ = l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
_ = l.listen.ListenAndServe()
}
}
@@ -98,7 +95,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
_ = l.listen.Shutdown(ctx)
}
closeClients(l.id)
@@ -106,33 +103,12 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info := &system.Info{
Version: l.sysInfo.Version,
Started: atomic.LoadInt64(&l.sysInfo.Started),
Time: atomic.LoadInt64(&l.sysInfo.Time),
Uptime: atomic.LoadInt64(&l.sysInfo.Uptime),
BytesReceived: atomic.LoadInt64(&l.sysInfo.BytesReceived),
BytesSent: atomic.LoadInt64(&l.sysInfo.BytesSent),
ClientsConnected: atomic.LoadInt64(&l.sysInfo.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&l.sysInfo.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&l.sysInfo.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&l.sysInfo.MessagesReceived),
MessagesSent: atomic.LoadInt64(&l.sysInfo.MessagesSent),
InflightDropped: atomic.LoadInt64(&l.sysInfo.InflightDropped),
Subscriptions: atomic.LoadInt64(&l.sysInfo.Subscriptions),
PacketsReceived: atomic.LoadInt64(&l.sysInfo.PacketsReceived),
PacketsSent: atomic.LoadInt64(&l.sysInfo.PacketsSent),
Retained: atomic.LoadInt64(&l.sysInfo.Retained),
Inflight: atomic.LoadInt64(&l.sysInfo.Inflight),
MemoryAlloc: atomic.LoadInt64(&l.sysInfo.MemoryAlloc),
Threads: atomic.LoadInt64(&l.sysInfo.Threads),
}
info := *l.sysInfo.Clone()
out, err := json.MarshalIndent(info, "", "\t")
if err != nil {
io.WriteString(w, err.Error())
_, _ = io.WriteString(w, err.Error())
}
w.Write(out)
_, _ = w.Write(out)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -10,7 +11,7 @@ import (
"testing"
"time"
"github.com/mochi-co/mqtt/system"
"github.com/mochi-mqtt/server/v2/system"
"github.com/stretchr/testify/require"
)
@@ -41,14 +42,14 @@ func TestHTTPStatsTLSProtocol(t *testing.T) {
TLSConfig: tlsConfigBasic,
}, nil)
l.Init(nil)
_ = l.Init(logger)
require.Equal(t, "https", l.Protocol())
}
func TestHTTPStatsInit(t *testing.T) {
sysInfo := new(system.Info)
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
err := l.Init(nil)
err := l.Init(logger)
require.NoError(t, err)
require.NotNil(t, l.sysInfo)
@@ -64,7 +65,7 @@ func TestHTTPStatsServeAndClose(t *testing.T) {
// setup http stats listener
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
err := l.Init(nil)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
@@ -112,7 +113,7 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) {
TLSConfig: tlsConfigBasic,
}, sysInfo)
err := l.Init(nil)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -8,7 +9,7 @@ import (
"net"
"sync"
"github.com/rs/zerolog"
"log/slog"
)
// Config contains configuration values for a listener.
@@ -21,18 +22,18 @@ type Config struct {
// EstablishFn is a callback function for establishing new clients.
type EstablishFn func(id string, c net.Conn) error
// CloseFunc is a callback function for closing all listener clients.
// CloseFn is a callback function for closing all listener clients.
type CloseFn func(id string)
// Listener is an interface for network listeners. A network listener listens
// for incoming client connections and adds them to the server.
type Listener interface {
Init(*zerolog.Logger) error // open the network address
Serve(EstablishFn) // starting actively listening for new connections
ID() string // return the id of the listener
Address() string // the address of the listener
Protocol() string // the protocol in use by the listener
Close(CloseFn) // stop and close the listener
Init(*slog.Logger) error // open the network address
Serve(EstablishFn) // starting actively listening for new connections
ID() string // return the id of the listener
Address() string // the address of the listener
Protocol() string // the protocol in use by the listener
Close(CloseFn) // stop and close the listener
}
// Listeners contains the network listeners for the broker.

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -10,14 +11,15 @@ import (
"testing"
"time"
"github.com/rs/zerolog"
"log/slog"
"github.com/stretchr/testify/require"
)
const testAddr = ":22222"
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -8,7 +9,7 @@ import (
"net"
"sync"
"github.com/rs/zerolog"
"log/slog"
)
// MockEstablisher is a function signature which can be used in testing.
@@ -52,7 +53,7 @@ func (l *MockListener) Serve(establisher EstablishFn) {
}
// Init initializes the listener.
func (l *MockListener) Init(log *zerolog.Logger) error {
func (l *MockListener) Init(log *slog.Logger) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -15,7 +16,7 @@ func TestMockEstablisher(t *testing.T) {
_, w := net.Pipe()
err := MockEstablisher("t1", w)
require.NoError(t, err)
w.Close()
_ = w.Close()
}
func TestNewMockListener(t *testing.T) {
@@ -85,7 +86,7 @@ func TestMockListenerServe(t *testing.T) {
require.Equal(t, true, closed)
<-o
mocked.Init(nil)
_ = mocked.Init(nil)
}
func TestMockListenerClose(t *testing.T) {

92
listeners/net.go Normal file
View File

@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: Jeroen Rinzema
package listeners
import (
"net"
"sync"
"sync/atomic"
"log/slog"
)
// Net is a listener for establishing client connections on basic TCP protocol.
type Net struct { // [MQTT-4.2.0-1]
mu sync.Mutex
listener net.Listener // a net.Listener which will listen for new clients
id string // the internal id of the listener
log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
func NewNet(id string, listener net.Listener) *Net {
return &Net{
id: id,
listener: listener,
}
}
// ID returns the id of the listener.
func (l *Net) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Net) Address() string {
return l.listener.Addr().String()
}
// Protocol returns the network of the listener.
func (l *Net) Protocol() string {
return l.listener.Addr().Network()
}
// Init initializes the listener.
func (l *Net) Init(log *slog.Logger) error {
l.log = log
return nil
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *Net) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listener.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn("", "error", err)
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *Net) Close(closeClients CloseFn) {
l.mu.Lock()
defer l.mu.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listener != nil {
err := l.listener.Close()
if err != nil {
return
}
}
}

105
listeners/net_test.go Normal file
View File

@@ -0,0 +1,105 @@
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewNet(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.id)
}
func TestNetID(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.ID())
}
func TestNetAddress(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, n.Addr().String(), l.Address())
}
func TestNetProtocol(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "tcp", l.Protocol())
}
func TestNetInit(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(logger)
l.Close(MockCloser)
require.NoError(t, err)
}
func TestNetServeAndClose(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestNetEstablishThenEnd(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
_, _ = net.Dial("tcp", n.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -9,18 +10,18 @@ import (
"sync"
"sync/atomic"
"github.com/rs/zerolog"
"log/slog"
)
// TCP is a listener for establishing client connections on basic TCP protocol.
type TCP struct { // [MQTT-4.2.0-1]
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
listen net.Listener // a net.Listener which will listen for new clients
config *Config // configuration values for the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
id string // the internal id of the listener
address string // the network address to bind to
listen net.Listener // a net.Listener which will listen for new clients
config *Config // configuration values for the listener
log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
@@ -52,7 +53,7 @@ func (l *TCP) Protocol() string {
}
// Init initializes the listener.
func (l *TCP) Init(log *zerolog.Logger) error {
func (l *TCP) Init(log *slog.Logger) error {
l.log = log
var err error
@@ -82,7 +83,7 @@ func (l *TCP) Serve(establish EstablishFn) {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
l.log.Warn("", "error", err)
}
}()
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -38,21 +39,21 @@ func TestTCPProtocolTLS(t *testing.T) {
TLSConfig: tlsConfigBasic,
})
l.Init(&logger)
_ = l.Init(logger)
defer l.listen.Close()
require.Equal(t, "tcp", l.Protocol())
}
func TestTCPInit(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
err := l.Init(logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewTCP("t2", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err = l2.Init(&logger)
err = l2.Init(logger)
l2.Close(MockCloser)
require.NoError(t, err)
require.NotNil(t, l2.config.TLSConfig)
@@ -60,7 +61,7 @@ func TestTCPInit(t *testing.T) {
func TestTCPServeAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
@@ -87,7 +88,7 @@ func TestTCPServeTLSAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(&logger)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
@@ -109,7 +110,7 @@ func TestTCPServeTLSAndClose(t *testing.T) {
func TestTCPEstablishThenEnd(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
@@ -123,7 +124,7 @@ func TestTCPEstablishThenEnd(t *testing.T) {
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", l.listen.Addr().String())
_, _ = net.Dial("tcp", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o

98
listeners/unixsock.go Normal file
View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"net"
"os"
"sync"
"sync/atomic"
"log/slog"
)
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once.
}
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock {
return &UnixSock{
id: id,
address: address,
}
}
// ID returns the id of the listener.
func (l *UnixSock) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *UnixSock) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *UnixSock) Protocol() string {
return "unix"
}
// Init initializes the listener.
func (l *UnixSock) Init(log *slog.Logger) error {
l.log = log
var err error
_ = os.Remove(l.address)
l.listen, err = net.Listen("unix", l.address)
return err
}
// Serve starts waiting for new UnixSock connections, and calls the establish
// connection callback for any received.
func (l *UnixSock) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn("", "error", err)
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *UnixSock) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testUnixAddr = "mochi.sock"
func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address)
}
func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.ID())
}
func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, testUnixAddr, l.Address())
}
func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "unix", l.Protocol())
}
func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr)
err = l2.Init(logger)
l2.Close(MockCloser)
require.NoError(t, err)
}
func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
_, _ = net.Dial("unix", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@@ -1,19 +1,22 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"errors"
"io"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"log/slog"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
var (
@@ -27,8 +30,8 @@ type Websocket struct { // [MQTT-4.2.0-1]
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // an http server for serving websocket connections
log *zerolog.Logger // server logger
listen *http.Server // a http server for serving websocket connections
log *slog.Logger // server logger
establish EstablishFn // the server's establish connection handler
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
end uint32 // ensure the close methods are only called once
@@ -73,7 +76,7 @@ func (l *Websocket) Protocol() string {
}
// Init initializes the listener.
func (l *Websocket) Init(log *zerolog.Logger) error {
func (l *Websocket) Init(log *slog.Logger) error {
l.log = log
mux := http.NewServeMux()
@@ -97,9 +100,9 @@ func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
}
defer c.Close()
err = l.establish(l.id, &wsConn{c.UnderlyingConn(), c})
err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c})
if err != nil {
l.log.Warn().Err(err).Send()
l.log.Warn("", "error", err)
}
}
@@ -109,9 +112,9 @@ func (l *Websocket) Serve(establish EstablishFn) {
l.establish = establish
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
_ = l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
_ = l.listen.ListenAndServe()
}
}
@@ -123,7 +126,7 @@ func (l *Websocket) Close(closeClients CloseFn) {
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
_ = l.listen.Shutdown(ctx)
}
closeClients(l.id)
@@ -133,28 +136,54 @@ func (l *Websocket) Close(closeClients CloseFn) {
type wsConn struct {
net.Conn
c *websocket.Conn
// reader for the current message (can be nil)
r io.Reader
}
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
func (ws *wsConn) Read(p []byte) (n int, err error) {
op, r, err := ws.c.NextReader()
if err != nil {
return
func (ws *wsConn) Read(p []byte) (int, error) {
if ws.r == nil {
op, r, err := ws.c.NextReader()
if err != nil {
return 0, err
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return 0, err
}
ws.r = r
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return
}
var n int
for {
// buffer is full, return what we've read so far
if n == len(p) {
return n, nil
}
return r.Read(p)
br, err := ws.r.Read(p[n:])
n += br
if err != nil {
// when ANY error occurs, we consider this the end of the current message (either because it really is, via
// io.EOF, or because something bad happened, in which case we want to drop the remainder)
ws.r = nil
if errors.Is(err, io.EOF) {
err = nil
}
return n, err
}
}
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (n int, err error) {
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
func (ws *wsConn) Write(p []byte) (int, error) {
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return
return 0, err
}
return len(p), nil

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
@@ -36,24 +37,24 @@ func TestWebsocketProtocol(t *testing.T) {
require.Equal(t, "ws", l.Protocol())
}
func TestWebsocketProtocoTLS(t *testing.T) {
func TestWebsocketProtocolTLS(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
require.Equal(t, "wss", l.Protocol())
}
func TestWebsockeInit(t *testing.T) {
func TestWebsocketInit(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Nil(t, l.listen)
err := l.Init(nil)
err := l.Init(logger)
require.NoError(t, err)
require.NotNil(t, l.listen)
}
func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
_ = l.Init(logger)
o := make(chan bool)
go func(o chan bool) {
@@ -76,7 +77,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(nil)
err := l.Init(logger)
require.NoError(t, err)
o := make(chan bool)
@@ -95,7 +96,7 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
_ = l.Init(logger)
e := make(chan bool)
l.establish = func(id string, c net.Conn) error {
@@ -109,5 +110,46 @@ func TestWebsocketUpgrade(t *testing.T) {
require.Equal(t, true, <-e)
s.Close()
ws.Close()
_ = ws.Close()
}
func TestWebsocketConnectionReads(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
_ = l.Init(nil)
recv := make(chan []byte)
l.establish = func(id string, c net.Conn) error {
var out []byte
for {
buf := make([]byte, 2048)
n, err := c.Read(buf)
require.NoError(t, err)
out = append(out, buf[:n]...)
if n < 2048 {
break
}
}
recv <- out
return nil
}
s := httptest.NewServer(http.HandlerFunc(l.handler))
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
require.NoError(t, err)
pkt := make([]byte, 3000) // make sure this is >2048
for i := 0; i < len(pkt); i++ {
pkt[i] = byte(i % 100)
}
err = ws.WriteMessage(websocket.BinaryMessage, pkt)
require.NoError(t, err)
got := <-recv
require.Equal(t, 3000, len(got))
require.Equal(t, pkt, got)
s.Close()
_ = ws.Close()
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -375,7 +376,7 @@ func TestEncodeUint16(t *testing.T) {
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result)
result = encodeUint16(65535)
result = encodeUint16(math.MaxUint16)
require.Equal(t, []byte{0xff, 0xff}, result)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
// Code contains a reason code and reason string for a response.
@@ -20,13 +21,14 @@ func (c Code) Error() string {
}
var (
// QosCodes indicicates the reason codes for each Qos byte.
// QosCodes indicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
2: CodeGrantedQos2,
}
CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"}
CodeSuccess = Code{Code: 0x00, Reason: "success"}
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
@@ -112,15 +114,36 @@ var (
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."}
// MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
Err3ClientIdentifierNotValid = Code{Code: 0x02}
Err3ServerUnavailable = Code{Code: 0x03}
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
Err3NotAuthorized = Code{Code: 0x05}
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
// This is required because MQTTv3 has different return byte specification.
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
V5CodesToV3 = map[Code]Code{
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
ErrServerUnavailable: Err3ServerUnavailable,
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
ErrBadUsernameOrPassword: Err3NotAuthorized,
}
)

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -18,7 +19,7 @@ func TestCodesString(t *testing.T) {
require.Equal(t, "test", c.String())
}
func TestCodesErrorr(t *testing.T) {
func TestCodesError(t *testing.T) {
c := Code{
Reason: "error",
Code: 0x1,

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

View File

@@ -1,42 +1,45 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"errors"
"fmt"
"math"
"strconv"
"strings"
"sync"
)
// All of the valid packet types and their packet identifier.
// All valid packet types and their packet identifiers.
const (
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Auth // 15
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Auth // 15
WillProperties byte = 99 // Special byte for validating Will Properties.
)
var (
// ErrNoValidPacketAvailable indicates the packet type byte provided does not exist in the mqtt specification.
ErrNoValidPacketAvailable error = errors.New("no valid packet available")
ErrNoValidPacketAvailable = errors.New("no valid packet available")
// PacketNames is a map of packet bytes to human readable names, for easier debugging.
// PacketNames is a map of packet bytes to human-readable names, for easier debugging.
PacketNames = map[byte]string{
0: "Reserved",
1: "Connect",
@@ -132,6 +135,7 @@ type Packet struct {
SessionPresent bool // session existed for connack
ReasonCode byte // reason code for a packet response (acks, etc)
ReservedBit byte // reserved, do not use (except in testing)
Ignore bool // if true, do not perform any message forwarding operations
}
// Mods specifies certain values required for certain mqtt v5 compliance within packet encoding/decoding.
@@ -172,6 +176,7 @@ type Subscription struct {
Qos byte
RetainAsPublished bool
NoLocal bool
FwdRetainedFlag bool // true if the subscription forms part of a publish response to a client subscription and packet is retained.
}
// Copy creates a new instance of a packet, but with an empty header for inheriting new QoS flags, etc.
@@ -207,7 +212,10 @@ func (pk *Packet) Copy(allowTransfer bool) Packet {
Created: pk.Created,
Expiry: pk.Expiry,
Origin: pk.Origin,
PacketID: pk.PacketID, // ... ? Packet ID must not be transferred (in this manner)
}
if allowTransfer {
p.PacketID = pk.PacketID
}
if len(pk.Connect.ProtocolName) > 0 {
@@ -264,28 +272,28 @@ func (s Subscription) Merge(n Subscription) Subscription {
}
// encode encodes a subscription and properties into bytes.
func (p Subscription) encode() byte {
func (s Subscription) encode() byte {
var flag byte
flag |= p.Qos
flag |= s.Qos
if p.NoLocal {
if s.NoLocal {
flag |= 1 << 2
}
if p.RetainAsPublished {
if s.RetainAsPublished {
flag |= 1 << 3
}
flag |= p.RetainHandling << 4
flag |= s.RetainHandling << 4
return flag
}
// decode decodes subscription bytes into a subscription struct.
func (p *Subscription) decode(b byte) {
p.Qos = b & 3 // byte
p.NoLocal = 1&(b>>2) > 0 // bool
p.RetainAsPublished = 1&(b>>3) > 0 // bool
p.RetainHandling = 3 & (b >> 4) // byte
func (s *Subscription) decode(b byte) {
s.Qos = b & 3 // byte
s.NoLocal = 1&(b>>2) > 0 // bool
s.RetainAsPublished = 1&(b>>3) > 0 // bool
s.RetainHandling = 3 & (b >> 4) // byte
}
// ConnectEncode encodes a connect packet.
@@ -308,7 +316,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
(&pk.Properties).Encode(pk, pb, 0)
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -317,7 +325,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.Connect.WillFlag {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
(&pk.Connect).WillProperties.Encode(pk, pb, 0)
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -335,7 +343,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -378,21 +386,21 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) //[MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) // [MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
if err != nil {
return ErrClientIdentifierNotValid // [MQTT-3.1.3-8]
}
if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
if pk.ProtocolVersion == 5 {
n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
if err != nil {
return ErrMalformedWillProperties
}
offset += n + 1
offset += n
}
pk.Connect.WillTopic, offset, err = decodeString(buf, offset)
@@ -407,6 +415,10 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
}
if pk.Connect.UsernameFlag { // [MQTT-3.1.3-12]
if offset >= len(buf) { // we are at the end of the packet
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
}
pk.Connect.Username, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedUsername
@@ -438,18 +450,14 @@ func (pk *Packet) ConnectValidate() Code {
return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3]
}
if len(pk.Connect.Password) > 65535 {
if len(pk.Connect.Password) > math.MaxUint16 {
return ErrProtocolViolationPasswordTooLong
}
if len(pk.Connect.Username) > 65535 {
if len(pk.Connect.Username) > math.MaxUint16 {
return ErrProtocolViolationUsernameTooLong
}
if pk.Connect.UsernameFlag && len(pk.Connect.Username) == 0 {
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
}
if !pk.Connect.UsernameFlag && len(pk.Connect.Username) > 0 {
return ErrProtocolViolationUsernameNoFlag // [MQTT-3.1.2-16]
}
@@ -462,7 +470,7 @@ func (pk *Packet) ConnectValidate() Code {
return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18]
}
if len(pk.Connect.ClientIdentifier) > 65535 {
if len(pk.Connect.ClientIdentifier) > math.MaxUint16 {
return ErrClientIdentifierNotValid
}
@@ -491,13 +499,13 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
nb.Write(pb.Bytes())
}
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -534,13 +542,13 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -603,7 +611,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload))
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
nb.Write(pb.Bytes())
}
@@ -611,7 +619,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -639,7 +647,7 @@ func (pk *Packet) PublishDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.Payload = buf[offset:]
@@ -687,7 +695,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
nb.WriteByte(pk.ReasonCode)
}
@@ -699,7 +707,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -828,7 +836,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes))
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
nb.Write(pb.Bytes())
}
@@ -836,7 +844,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -856,7 +864,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.ReasonCodes = buf[offset:]
@@ -885,7 +893,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -893,7 +901,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -913,7 +921,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
var filter string
@@ -980,7 +988,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -988,7 +996,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -1009,7 +1017,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
pk.ReasonCodes = buf[offset:]
}
@@ -1033,7 +1041,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -1041,7 +1049,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}
@@ -1061,7 +1069,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
var filter string
@@ -1096,12 +1104,12 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
pk.FixedHeader.Remaining = nb.Len()
pk.FixedHeader.Encode(buf)
nb.WriteTo(buf)
_, _ = nb.WriteTo(buf)
return nil
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -149,7 +150,7 @@ func TestPacketEncode(t *testing.T) {
}
pk := new(Packet)
copier.Copy(pk, wanted.Packet)
_ = copier.Copy(pk, wanted.Packet)
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
pk.Mods.AllowResponseInfo = true
@@ -217,7 +218,7 @@ func TestPacketDecode(t *testing.T) {
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
pk.Mods.AllowResponseInfo = true
pk.FixedHeader.Decode(wanted.RawBytes[0])
_ = pk.FixedHeader.Decode(wanted.RawBytes[0])
if len(wanted.RawBytes) > 0 {
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
}
@@ -463,6 +464,9 @@ func TestCopy(t *testing.T) {
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
pkcc := tt.Packet.Copy(false)
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
}
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -41,11 +42,11 @@ const (
// validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1},
PropMessageExpiryInterval: {Publish: 1},
PropContentType: {Publish: 1},
PropResponseTopic: {Publish: 1},
PropCorrelationData: {Publish: 1},
PropPayloadFormat: {Publish: 1, WillProperties: 1},
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
PropContentType: {Publish: 1, WillProperties: 1},
PropResponseTopic: {Publish: 1, WillProperties: 1},
PropCorrelationData: {Publish: 1, WillProperties: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1},
@@ -53,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {Connect: 1},
PropWillDelayInterval: {WillProperties: 1},
PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1},
@@ -63,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1},
@@ -76,7 +77,7 @@ type UserProperty struct { // [MQTT-1.5.7-1]
Val string `json:"v"`
}
// Properties contains all of the mqtt v5 properties available for a packet.
// Properties contains all mqtt v5 properties available for a packet.
// Some properties have valid values of 0 or not-present. In this case, we opt for
// property flags to indicate the usage of property.
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
@@ -193,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool {
}
// Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
if p == nil {
return
}
var buf bytes.Buffer
pkt := pk.FixedHeader.Type
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat)
@@ -216,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData))
}
@@ -276,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RequestResponseInfo)
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo))
}
@@ -288,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString)
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
buf.WriteByte(PropReasonString)
buf.Write(b)
}
@@ -321,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RetainAvailable)
}
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := bytes.NewBuffer([]byte{})
for _, v := range p.User {
pb.WriteByte(PropUser)
@@ -330,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
}
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
buf.Write(pb.Bytes())
}
}
@@ -356,22 +355,23 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
}
encodeLength(b, int64(buf.Len()))
buf.WriteTo(b) // [MQTT-3.1.3-10]
_, _ = buf.WriteTo(b) // [MQTT-3.1.3-10]
}
// Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
if p == nil {
return 0, nil
}
n, _, err = DecodeLength(b)
var bu int
n, bu, err = DecodeLength(b)
if err != nil {
return n, err
return n + bu, err
}
if n == 0 {
return n, nil
return n + bu, nil
}
bt := b.Bytes()
@@ -379,11 +379,11 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
for offset := 0; offset < n; {
k, offset, err = decodeByte(bt, offset)
if err != nil {
return n, err
return n + bu, err
}
if _, ok := validPacketProperties[k][pk]; !ok {
return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
if _, ok := validPacketProperties[k][pkt]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
}
switch k {
@@ -405,7 +405,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
if err != nil {
return n, err
return n + bu, err
}
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
offset += bu
@@ -451,7 +451,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
var k, v string
k, offset, err = decodeString(bt, offset)
if err != nil {
return n, err
return n + bu, err
}
v, offset, err = decodeString(bt, offset)
p.User = append(p.User, UserProperty{Key: k, Val: v})
@@ -469,9 +469,9 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
}
if err != nil {
return n, err
return n + bu, err
}
}
return n, nil
return n + bu, nil
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
@@ -201,14 +202,14 @@ func init() {
func TestEncodeProperties(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes())
}
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
@@ -218,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
@@ -231,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) {
pr := tmp{}
b := bytes.NewBuffer([]byte{})
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
pr.p.Encode(Reserved, Mods{}, b, 0)
require.Equal(t, []byte{}, b.Bytes())
}
@@ -239,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties)
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes())
}
@@ -249,7 +250,7 @@ func TestDecodeProperties(t *testing.T) {
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172, n)
require.Equal(t, 172+2, n)
require.EqualValues(t, propertiesStruct, *props)
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
// TPacketCase contains data for cross-checking the encoding and decoding
@@ -39,7 +40,6 @@ const (
TConnectMqtt5
TConnectMqtt5LWT
TConnectClean
TConnectCleanLWT
TConnectUserPass
TConnectUserPassLWT
TConnectMalProtocolName
@@ -60,7 +60,6 @@ const (
TConnectInvalidProtocolVersion2
TConnectInvalidReservedBit
TConnectInvalidClientIDTooLong
TConnectInvalidPasswordNoUsername
TConnectInvalidFlagNoUsername
TConnectInvalidFlagNoPassword
TConnectInvalidUsernameNoFlag
@@ -70,7 +69,7 @@ const (
TConnectInvalidWillFlagNoPayload
TConnectInvalidWillFlagQosOutOfRange
TConnectInvalidWillSurplusRetain
TConnectNotCleanNoClientID
TConnectZeroByteUsername
TConnectSpecInvalidUTF8D800
TConnectSpecInvalidUTF8DFFF
TConnectSpecInvalidUTF80000
@@ -81,6 +80,7 @@ const (
TConnackAcceptedAdjustedExpiryInterval
TConnackMinMqtt5
TConnackMinCleanMqtt5
TConnackServerKeepalive
TConnackInvalidMinMqtt5
TConnackBadProtocolVersion
TConnackProtocolViolationNoSession
@@ -88,6 +88,7 @@ const (
TConnackServerUnavailable
TConnackBadUsernamePassword
TConnackBadUsernamePasswordNoSession
TConnackMqtt5BadUsernamePasswordNoSession
TConnackNotAuthorised
TConnackMalSessionPresent
TConnackMalReturnCode
@@ -100,6 +101,7 @@ const (
TPublishBasicMqtt5
TPublishMqtt5
TPublishQos1
TPublishQos1Mqtt5
TPublishQos1NoPayload
TPublishQos1Dup
TPublishQos2
@@ -127,11 +129,14 @@ const (
TPublishSpecDenySysTopic
TPuback
TPubackMqtt5
TPubackMqtt5NotAuthorized
TPubackMalPacketID
TPubackMalProperties
TPubackUnexpectedError
TPubrec
TPubrecMqtt5
TPubrecMqtt5IDInUse
TPubrecMqtt5NotAuthorized
TPubrecMalPacketID
TPubrecMalProperties
TPubrecMalReasonCode
@@ -179,7 +184,6 @@ const (
TUnsubscribe
TUnsubscribeMany
TUnsubscribeMqtt5
TUnsubscribeDropProperties
TUnsubscribeMalPacketID
TUnsubscribeMalTopicName
TUnsubscribeMalProperties
@@ -197,7 +201,6 @@ const (
TDisconnect
TDisconnectTakeover
TDisconnectMqtt5
TDisconnectNormalMqtt5
TDisconnectSecondConnect
TDisconnectReceiveMaximum
TDisconnectDropProperties
@@ -248,26 +251,26 @@ var TPacketData = map[byte]TPacketCases{
Desc: "mqtt v3.1.1",
Primary: true,
RawBytes: []byte{
Connect << 4, 16, // Fixed header
Connect << 4, 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Packet Flags
0, 60, // Keepalive
0, 4, // Client ID - MSB+LSB
'z', 'e', 'n', '3', // Client ID "zen"
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connect,
Remaining: 16,
Remaining: 15,
},
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: false,
Keepalive: 60,
ClientIdentifier: "zen3",
ClientIdentifier: "zen",
},
},
},
@@ -424,9 +427,9 @@ var TPacketData = map[byte]TPacketCases{
Connect << 4, 28, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
194, // Packet Flags
0, 20, // Keepalive
4, // Protocol Version
0 | 1<<6 | 1<<7, // Packet Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 5, // Username MSB+LSB
@@ -442,7 +445,7 @@ var TPacketData = map[byte]TPacketCases{
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: true,
Clean: false,
Keepalive: 20,
ClientIdentifier: "zen",
UsernameFlag: true,
@@ -496,6 +499,43 @@ var TPacketData = map[byte]TPacketCases{
},
},
},
{
Case: TConnectZeroByteUsername,
Desc: "username flag but 0 byte username",
Group: "decode",
RawBytes: []byte{
Connect << 4, 23, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
5, // Protocol Version
130, // Packet Flags
0, 30, // Keepalive
5, // length
17, 0, 0, 0, 120, // Session Expiry Interval (17)
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 0, // Username MSB+LSB
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connect,
Remaining: 23,
},
ProtocolVersion: 5,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: true,
Keepalive: 30,
ClientIdentifier: "zen",
Username: []byte{},
UsernameFlag: true,
},
Properties: Properties{
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
},
},
},
// Fail States
{
@@ -622,6 +662,24 @@ var TPacketData = map[byte]TPacketCases{
'm', 'o', 'c',
},
},
{
Case: TConnectInvalidFlagNoUsername,
Desc: "username flag with no username bytes",
Group: "decode",
FailFirst: ErrProtocolViolationFlagNoUsername,
RawBytes: []byte{
Connect << 4, 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
5, // Protocol Version
130, // Flags
0, 20, // Keepalive
0,
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
},
},
{
Case: TConnectMalPassword,
Desc: "malformed password",
@@ -782,20 +840,6 @@ var TPacketData = map[byte]TPacketCases{
},
},
},
{
Case: TConnectInvalidFlagNoUsername,
Desc: "has username flag but no username",
Group: "validate",
Expect: ErrProtocolViolationFlagNoUsername,
Packet: &Packet{
FixedHeader: FixedHeader{Type: Connect},
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
UsernameFlag: true,
},
},
},
{
Case: TConnectInvalidUsernameNoFlag,
Desc: "has username but no flag",
@@ -1042,25 +1086,22 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted, no session, adjusted expiry interval mqtt5",
Primary: true,
RawBytes: []byte{
Connack << 4, 11, // fixed header
Connack << 4, 8, // fixed header
0, // Session present
CodeSuccess.Code,
8, // length
5, // length
17, 0, 0, 0, 120, // Session Expiry Interval (17)
19, 0, 10, // Server Keep Alive (19)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 11,
Remaining: 8,
},
ReasonCode: CodeSuccess.Code,
Properties: Properties{
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
ServerKeepAlive: uint16(10),
ServerKeepAliveFlag: true,
},
},
},
@@ -1147,28 +1188,25 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted min properties mqtt5",
Primary: true,
RawBytes: []byte{
Connack << 4, 16, // fixed header
Connack << 4, 13, // fixed header
1, // existing session
CodeSuccess.Code,
13, // Properties length
10, // Properties length
18, 0, 5, 'm', 'o', 'c', 'h', 'i', // Assigned Client ID (18)
19, 0, 20, // Server Keep Alive (19)
36, 1, // Maximum Qos (36)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 16,
Remaining: 13,
},
SessionPresent: true,
ReasonCode: CodeSuccess.Code,
Properties: Properties{
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
AssignedClientID: "mochi",
MaximumQos: byte(1),
MaximumQosFlag: true,
AssignedClientID: "mochi",
MaximumQos: byte(1),
MaximumQosFlag: true,
},
},
},
@@ -1177,11 +1215,10 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted min properties mqtt5b",
Primary: true,
RawBytes: []byte{
Connack << 4, 6, // fixed header
Connack << 4, 3, // fixed header
0, // existing session
CodeSuccess.Code,
3, // Properties length
19, 0, 10, // server keepalive
0, // Properties length
},
Packet: &Packet{
ProtocolVersion: 5,
@@ -1191,6 +1228,27 @@ var TPacketData = map[byte]TPacketCases{
},
SessionPresent: false,
ReasonCode: CodeSuccess.Code,
},
},
{
Case: TConnackServerKeepalive,
Desc: "server set keepalive",
Primary: true,
RawBytes: []byte{
Connack << 4, 6, // fixed header
1, // existing session
CodeSuccess.Code,
3, // Properties length
19, 0, 10, // server keepalive
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 6,
},
SessionPresent: true,
ReasonCode: CodeSuccess.Code,
Properties: Properties{
ServerKeepAlive: uint16(10),
ServerKeepAliveFlag: true,
@@ -1202,26 +1260,23 @@ var TPacketData = map[byte]TPacketCases{
Desc: "failure min properties mqtt5",
Primary: true,
RawBytes: append([]byte{
Connack << 4, 26, // fixed header
Connack << 4, 23, // fixed header
0, // No existing session
ErrUnspecifiedError.Code,
// Properties
23, // length
19, 0, 20, // Server Keep Alive (19)
20, // length
31, 0, 17, // Reason String (31)
}, []byte(ErrUnspecifiedError.Reason)...),
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 25,
Remaining: 23,
},
SessionPresent: false,
ReasonCode: ErrUnspecifiedError.Code,
Properties: Properties{
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
ReasonString: ErrUnspecifiedError.Reason,
ReasonString: ErrUnspecifiedError.Reason,
},
},
},
@@ -1315,10 +1370,28 @@ var TPacketData = map[byte]TPacketCases{
Desc: "bad username or password no session",
RawBytes: []byte{
Connack << 4, 2, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0, // No session present
Err3NotAuthorized.Code, // use v3 remapping
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
},
ReasonCode: Err3NotAuthorized.Code,
},
},
{
Case: TConnackMqtt5BadUsernamePasswordNoSession,
Desc: "mqtt5 bad username or password no session",
RawBytes: []byte{
Connack << 4, 3, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0,
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
@@ -1326,6 +1399,7 @@ var TPacketData = map[byte]TPacketCases{
ReasonCode: ErrBadUsernameOrPassword.Code,
},
},
{
Case: TConnackNotAuthorised,
Desc: "not authorised",
@@ -1630,6 +1704,43 @@ var TPacketData = map[byte]TPacketCases{
PacketID: 7,
},
},
{
Case: TPublishQos1Mqtt5,
Desc: "mqtt v5",
Primary: true,
RawBytes: []byte{
Publish<<4 | 1<<1, 37, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
0, 7, // Packet ID - LSB+MSB
// Properties
16, // length
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Publish,
Remaining: 37,
Qos: 1,
},
PacketID: 7,
TopicName: "a/b/c",
Properties: Properties{
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
Payload: []byte("hello mochi"),
},
},
{
Case: TPublishQos1Dup,
Desc: "qos:1, dup:true, packet id",
@@ -1803,13 +1914,10 @@ var TPacketData = map[byte]TPacketCases{
Case: TPublishRetainMqtt5,
Desc: "retain mqtt5",
RawBytes: []byte{
Publish<<4 | 1<<0, 35, // Fixed header
Publish<<4 | 1<<0, 19, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
16, // properties length
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
0, // properties length
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
},
Packet: &Packet{
@@ -1817,18 +1925,11 @@ var TPacketData = map[byte]TPacketCases{
FixedHeader: FixedHeader{
Type: Publish,
Retain: true,
Remaining: 35,
Remaining: 19,
},
TopicName: "a/b/c",
Properties: Properties{
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
Payload: []byte("hello mochi"),
TopicName: "a/b/c",
Properties: Properties{},
Payload: []byte("hello mochi"),
},
},
{
@@ -2171,6 +2272,66 @@ var TPacketData = map[byte]TPacketCases{
},
},
},
{
Case: TPubackMqtt5NotAuthorized,
Desc: "QOS 1 publish not authorized mqtt5",
Primary: true,
RawBytes: []byte{
Puback << 4, 37, // Fixed header
0, 7, // Packet ID - LSB+MSB
ErrNotAuthorized.Code, // Reason Code
33, // Properties Length
31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u',
't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31)
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Puback,
Remaining: 31,
},
PacketID: 7,
ReasonCode: ErrNotAuthorized.Code,
Properties: Properties{
ReasonString: ErrNotAuthorized.Reason,
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
},
},
{
Case: TPubackUnexpectedError,
Desc: "unexpected error",
Group: "decode",
RawBytes: []byte{
Puback << 4, 29, // Fixed header
0, 7, // Packet ID - LSB+MSB
ErrPayloadFormatInvalid.Code, // Reason Code
25, // Properties Length
31, 0, 22, 'p', 'a', 'y', 'l', 'o', 'a', 'd',
' ', 'f', 'o', 'r', 'm', 'a', 't',
' ', 'i', 'n', 'v', 'a', 'l', 'i', 'd', // Reason String (31)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Puback,
Remaining: 28,
},
PacketID: 7,
ReasonCode: ErrPayloadFormatInvalid.Code,
Properties: Properties{
ReasonString: ErrPayloadFormatInvalid.Reason,
},
},
},
// Fail states
{
@@ -2252,14 +2413,17 @@ var TPacketData = map[byte]TPacketCases{
Desc: "packet id in use mqtt5",
Primary: true,
RawBytes: []byte{
Pubrec << 4, 31, // Fixed header
Pubrec << 4, 47, // Fixed header
0, 7, // Packet ID - LSB+MSB
ErrPacketIdentifierInUse.Code, // Reason Code
27, // Properties Length
43, // Properties Length
31, 0, 24, 'p', 'a', 'c', 'k', 'e', 't',
' ', 'i', 'd', 'e', 'n', 't', 'i', 'f', 'i', 'e', 'r',
' ', 'i', 'n',
' ', 'u', 's', 'e', // Reason String (31)
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
},
Packet: &Packet{
ProtocolVersion: 5,
@@ -2271,6 +2435,46 @@ var TPacketData = map[byte]TPacketCases{
ReasonCode: ErrPacketIdentifierInUse.Code,
Properties: Properties{
ReasonString: ErrPacketIdentifierInUse.Reason,
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
},
},
{
Case: TPubrecMqtt5NotAuthorized,
Desc: "QOS 2 publish not authorized mqtt5",
Primary: true,
RawBytes: []byte{
Pubrec << 4, 37, // Fixed header
0, 7, // Packet ID - LSB+MSB
ErrNotAuthorized.Code, // Reason Code
33, // Properties Length
31, 0, 14, 'n', 'o', 't', ' ', 'a', 'u',
't', 'h', 'o', 'r', 'i', 'z', 'e', 'd', // Reason String (31)
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Pubrec,
Remaining: 31,
},
PacketID: 7,
ReasonCode: ErrNotAuthorized.Code,
Properties: Properties{
ReasonString: ErrNotAuthorized.Reason,
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
},
},

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (

662
server.go

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,11 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package system
import "sync/atomic"
// Info contains atomic counters and values for various server statistics
// commonly found in $SYS topics (and others).
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
@@ -19,6 +22,7 @@ type Info struct {
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
Retained int64 `json:"retained"` // total number of retained messages active on the broker
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
@@ -28,3 +32,30 @@ type Info struct {
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
}
// Clone makes a copy of Info using atomic operation
func (i *Info) Clone() *Info {
return &Info{
Version: i.Version,
Started: atomic.LoadInt64(&i.Started),
Time: atomic.LoadInt64(&i.Time),
Uptime: atomic.LoadInt64(&i.Uptime),
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
BytesSent: atomic.LoadInt64(&i.BytesSent),
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
Retained: atomic.LoadInt64(&i.Retained),
Inflight: atomic.LoadInt64(&i.Inflight),
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
Threads: atomic.LoadInt64(&i.Threads),
}
}

37
system/system_test.go Normal file
View File

@@ -0,0 +1,37 @@
package system
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestClone(t *testing.T) {
o := &Info{
Version: "version",
Started: 1,
Time: 2,
Uptime: 3,
BytesReceived: 4,
BytesSent: 5,
ClientsConnected: 6,
ClientsMaximum: 7,
ClientsTotal: 8,
ClientsDisconnected: 9,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
Retained: 12,
Inflight: 13,
InflightDropped: 14,
Subscriptions: 15,
PacketsReceived: 16,
PacketsSent: 17,
MemoryAlloc: 18,
Threads: 19,
}
n := o.Clone()
require.Equal(t, o, n)
}

197
topics.go
View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
@@ -8,7 +9,7 @@ import (
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-mqtt/server/v2/packets"
)
var (
@@ -185,6 +186,65 @@ func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscriptio
return m
}
// InlineSubFn is the signature for a callback function which will be called
// when an inline client receives a message on a topic it is subscribed to.
// The sub argument contains information about the subscription that was matched for any filters.
type InlineSubFn func(cl *Client, sub packets.Subscription, pk packets.Packet)
// InlineSubscriptions represents a map of internal subscriptions keyed on client.
type InlineSubscriptions struct {
internal map[int]InlineSubscription
sync.RWMutex
}
// NewInlineSubscriptions returns a new instance of InlineSubscriptions.
func NewInlineSubscriptions() *InlineSubscriptions {
return &InlineSubscriptions{
internal: map[int]InlineSubscription{},
}
}
// Add adds a new internal subscription for a client id.
func (s *InlineSubscriptions) Add(val InlineSubscription) {
s.Lock()
defer s.Unlock()
s.internal[val.Identifier] = val
}
// GetAll returns all internal subscriptions.
func (s *InlineSubscriptions) GetAll() map[int]InlineSubscription {
s.RLock()
defer s.RUnlock()
m := map[int]InlineSubscription{}
for k, v := range s.internal {
m[k] = v
}
return m
}
// Get returns an internal subscription for a client id.
func (s *InlineSubscriptions) Get(id int) (val InlineSubscription, ok bool) {
s.RLock()
defer s.RUnlock()
val, ok = s.internal[id]
return val, ok
}
// Len returns the number of internal subscriptions.
func (s *InlineSubscriptions) Len() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Delete removes an internal subscription by the client id.
func (s *InlineSubscriptions) Delete(id int) {
s.Lock()
defer s.Unlock()
delete(s.internal, id)
}
// Subscriptions is a map of subscriptions keyed on client.
type Subscriptions struct {
internal map[string]packets.Subscription
@@ -243,11 +303,17 @@ func (s *Subscriptions) Delete(id string) {
// ClientSubscriptions is a map of aggregated subscriptions for a client.
type ClientSubscriptions map[string]packets.Subscription
type InlineSubscription struct {
packets.Subscription
Handler InlineSubFn
}
// Subscribers contains the shared and non-shared subscribers matching a topic.
type Subscribers struct {
Shared map[string]map[string]packets.Subscription
SharedSelected map[string]packets.Subscription
Subscriptions map[string]packets.Subscription
Shared map[string]map[string]packets.Subscription
SharedSelected map[string]packets.Subscription
Subscriptions map[string]packets.Subscription
InlineSubscriptions map[int]InlineSubscription
}
// SelectShared returns one subscriber for each shared subscription group.
@@ -297,9 +363,45 @@ func NewTopicsIndex() *TopicsIndex {
}
}
// InlineSubscribe adds a new internal subscription for a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) InlineSubscribe(subscription InlineSubscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
n := x.set(subscription.Filter, 0)
_, existed = n.inlineSubscriptions.Get(subscription.Identifier)
n.inlineSubscriptions.Add(subscription)
return !existed
}
// InlineUnsubscribe removes an internal subscription for a topic filter associated with a specific client,
// returning true if the subscription existed.
func (x *TopicsIndex) InlineUnsubscribe(id int, filter string) bool {
x.root.Lock()
defer x.root.Unlock()
particle := x.seek(filter, 0)
if particle == nil {
return false
}
particle.inlineSubscriptions.Delete(id)
if particle.inlineSubscriptions.Len() == 0 {
x.trim(particle)
}
return true
}
// Subscribe adds a new subscription for a client to a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
prefix, _ := isolateParticle(subscription.Filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
@@ -319,8 +421,13 @@ func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription
// Unsubscribe removes a subscription filter for a client, returning true if the
// subscription existed.
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
x.root.Lock()
defer x.root.Unlock()
var d int
if strings.HasPrefix(filter, SharePrefix) {
prefix, _ := isolateParticle(filter, 0)
shareSub := strings.EqualFold(prefix, SharePrefix)
if shareSub {
d = 2
}
@@ -329,8 +436,7 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
return false
}
prefix, _ := isolateParticle(filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
if shareSub {
group, _ := isolateParticle(filter, 1)
particle.shared.Delete(group, client)
} else {
@@ -345,7 +451,12 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
x.root.Lock()
defer x.root.Unlock()
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()
if len(pk.Payload) > 0 {
n.retainPath = pk.TopicName
x.Retained.Add(pk.TopicName, pk)
@@ -360,6 +471,7 @@ func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
n.retainPath = ""
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
x.trim(n)
return out
}
@@ -470,9 +582,10 @@ func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []pack
// their subscription ids and highest qos.
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
return x.scanSubscribers(topic, 0, nil, &Subscribers{
Shared: map[string]map[string]packets.Subscription{},
SharedSelected: map[string]packets.Subscription{},
Subscriptions: map[string]packets.Subscription{},
Shared: map[string]map[string]packets.Subscription{},
SharedSelected: map[string]packets.Subscription{},
Subscriptions: map[string]packets.Subscription{},
InlineSubscriptions: map[int]InlineSubscription{},
})
}
@@ -487,20 +600,30 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
}
key, hasNext := isolateParticle(topic, d)
for _, partKey := range []string{key, "+", "#"} {
for _, partKey := range []string{key, "+"} {
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
if wild := particle.particles.get("#"); wild != nil && partKey != "#" && partKey != "+" {
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
}
if hasNext {
x.scanSubscribers(topic, d+1, particle, subs)
} else {
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
x.gatherInlineSubscriptions(particle, subs)
if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
x.gatherSharedSubscriptions(wild, subs)
x.gatherInlineSubscriptions(particle, subs)
}
}
}
}
if particle := n.particles.get("#"); particle != nil {
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
x.gatherInlineSubscriptions(particle, subs)
}
return subs
}
@@ -541,6 +664,17 @@ func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscr
}
}
// gatherSharedSubscriptions gathers all inline subscriptions for a particle.
func (x *TopicsIndex) gatherInlineSubscriptions(particle *particle, subs *Subscribers) {
if subs.InlineSubscriptions == nil {
subs.InlineSubscriptions = map[int]InlineSubscription{}
}
for id, inline := range particle.inlineSubscriptions.GetAll() {
subs.InlineSubscriptions[id] = inline
}
}
// isolateParticle extracts a particle between d / and d+1 / without allocations.
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
var next, end int
@@ -571,7 +705,7 @@ func IsSharedFilter(filter string) bool {
// IsValidFilter returns true if the filter is valid.
func IsValidFilter(filter string, forPublish bool) bool {
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs.
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publish.
return false // [MQTT-4.7.3-1]
}
@@ -612,22 +746,25 @@ func IsValidFilter(filter string, forPublish bool) bool {
// particle is a child node on the tree.
type particle struct {
key string // the key of the particle
parent *particle // a pointer to the parent of the particle
particles particles // a map of child particles
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
retainPath string // path of a retained message
key string // the key of the particle
parent *particle // a pointer to the parent of the particle
particles particles // a map of child particles
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
inlineSubscriptions *InlineSubscriptions // a map of inline subscriptions for this particle
retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle
}
// newParticle returns a pointer to a new instance of particle.
func newParticle(key string, parent *particle) *particle {
return &particle{
key: key,
parent: parent,
particles: newParticles(),
subscriptions: NewSubscriptions(),
shared: NewSharedSubscriptions(),
key: key,
parent: parent,
particles: newParticles(),
subscriptions: NewSubscriptions(),
shared: NewSharedSubscriptions(),
inlineSubscriptions: NewInlineSubscriptions(),
}
}

View File

@@ -1,12 +1,14 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"fmt"
"testing"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/stretchr/testify/require"
)
@@ -318,7 +320,7 @@ func TestUnsubscribeShared(t *testing.T) {
require.True(t, exists)
require.Equal(t, byte(2), client.Qos)
require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1"))
require.True(t, index.Unsubscribe("$share/tmp/a/b/c", "cl1"))
_, exists = final.shared.Get("tmp", "cl1")
require.False(t, exists)
}
@@ -500,28 +502,40 @@ func TestScanSubscribers(t *testing.T) {
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 4, len(subs.Subscriptions))
require.Equal(t, 3, len(subs.Subscriptions))
require.Contains(t, subs.Subscriptions, "cl1")
require.Contains(t, subs.Subscriptions, "cl2")
require.Contains(t, subs.Subscriptions, "cl3")
require.Contains(t, subs.Subscriptions, "cl4")
require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos)
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"])
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
subs = index.scanSubscribers("d/e/f/g", 0, nil, new(Subscribers))
require.Equal(t, 1, len(subs.Subscriptions))
require.Contains(t, subs.Subscriptions, "cl4")
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
subs = index.scanSubscribers("", 0, nil, new(Subscribers))
require.Equal(t, 0, len(subs.Subscriptions))
}
func TestScanSubscribersTopicInheritanceBug(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 0, Filter: "a/b/c"})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/b"})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 1, len(subs.Subscriptions))
}
func TestScanSubscribersShared(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
@@ -530,8 +544,9 @@ func TestScanSubscribersShared(t *testing.T) {
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
index.Subscribe("cl5", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c/#"})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 3, len(subs.Shared))
require.Equal(t, 4, len(subs.Shared))
}
func TestSelectSharedSubscriber(t *testing.T) {
@@ -839,3 +854,215 @@ func TestNewTopicAliases(t *testing.T) {
require.NotNil(t, a.Outbound)
require.Equal(t, uint16(5), a.Outbound.maximum)
}
func TestNewInlineSubscriptions(t *testing.T) {
subscriptions := NewInlineSubscriptions()
require.NotNil(t, subscriptions)
require.NotNil(t, subscriptions.internal)
require.Equal(t, 0, subscriptions.Len())
}
func TestInlineSubscriptionAdd(t *testing.T) {
subscriptions := NewInlineSubscriptions()
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
subscription := InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
Handler: handler,
}
subscriptions.Add(subscription)
sub, ok := subscriptions.Get(1)
require.True(t, ok)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
}
func TestInlineSubscriptionGet(t *testing.T) {
subscriptions := NewInlineSubscriptions()
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
subscription := InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
Handler: handler,
}
subscriptions.Add(subscription)
sub, ok := subscriptions.Get(1)
require.True(t, ok)
require.Equal(t, "a/b/c", sub.Filter)
require.Equal(t, fmt.Sprintf("%p", handler), fmt.Sprintf("%p", sub.Handler))
_, ok = subscriptions.Get(999)
require.False(t, ok)
}
func TestInlineSubscriptionsGetAll(t *testing.T) {
subscriptions := NewInlineSubscriptions()
subscriptions.Add(InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
})
subscriptions.Add(InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
})
subscriptions.Add(InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2},
})
subscriptions.Add(InlineSubscription{
Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 3},
})
allSubs := subscriptions.GetAll()
require.Len(t, allSubs, 3)
require.Contains(t, allSubs, 1)
require.Contains(t, allSubs, 2)
require.Contains(t, allSubs, 3)
}
func TestInlineSubscriptionDelete(t *testing.T) {
subscriptions := NewInlineSubscriptions()
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
subscription := InlineSubscription{
Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1},
Handler: handler,
}
subscriptions.Add(subscription)
subscriptions.Delete(1)
_, ok := subscriptions.Get(1)
require.False(t, ok)
require.Empty(t, subscriptions.GetAll())
require.Zero(t, subscriptions.Len())
}
func TestInlineSubscribe(t *testing.T) {
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
tt := []struct {
desc string
filter string
subscription InlineSubscription
wasNew bool
}{
{
desc: "subscribe",
filter: "a/b/c",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
wasNew: true,
},
{
desc: "subscribe existed",
filter: "a/b/c",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 1}},
wasNew: false,
},
{
desc: "subscribe different identifier",
filter: "a/b/c",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c", Identifier: 2}},
wasNew: true,
},
{
desc: "subscribe case sensitive didnt exist",
filter: "A/B/c",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "A/B/c", Identifier: 1}},
wasNew: true,
},
{
desc: "wildcard+ sub",
filter: "d/+",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/+", Identifier: 1}},
wasNew: true,
},
{
desc: "wildcard# sub",
filter: "d/e/#",
subscription: InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/#", Identifier: 1}},
wasNew: true,
},
}
index := NewTopicsIndex()
for _, tx := range tt {
t.Run(tx.desc, func(t *testing.T) {
require.Equal(t, tx.wasNew, index.InlineSubscribe(tx.subscription))
})
}
final := index.root.particles.get("a").particles.get("b").particles.get("c")
require.NotNil(t, final)
}
func TestInlineUnsubscribe(t *testing.T) {
handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
// handler logic
}
index := NewTopicsIndex()
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
sub, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index = NewTopicsIndex()
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/c/d", Identifier: 1}})
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 2}})
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(2)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "a/b/+/d", Identifier: 1}})
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "d/e/f", Identifier: 1}})
sub, exists = index.root.particles.get("d").particles.get("e").particles.get("f").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
index.InlineSubscribe(InlineSubscription{Handler: handler, Subscription: packets.Subscription{Filter: "#", Identifier: 1}})
sub, exists = index.root.particles.get("#").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
ok := index.InlineUnsubscribe(1, "a/b/c/d")
require.True(t, ok)
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
sub, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").inlineSubscriptions.Get(1)
require.NotNil(t, sub)
require.True(t, exists)
ok = index.InlineUnsubscribe(1, "d/e/f")
require.True(t, ok)
require.NotNil(t, index.root.particles.get("d").particles.get("e").particles.get("f"))
ok = index.InlineUnsubscribe(1, "not/exist")
require.False(t, ok)
}

View File

@@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2016 Yasuhiro Matsumoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,48 +0,0 @@
# go-colorable
[![Build Status](https://github.com/mattn/go-colorable/workflows/test/badge.svg)](https://github.com/mattn/go-colorable/actions?query=workflow%3Atest)
[![Codecov](https://codecov.io/gh/mattn/go-colorable/branch/master/graph/badge.svg)](https://codecov.io/gh/mattn/go-colorable)
[![GoDoc](https://godoc.org/github.com/mattn/go-colorable?status.svg)](http://godoc.org/github.com/mattn/go-colorable)
[![Go Report Card](https://goreportcard.com/badge/mattn/go-colorable)](https://goreportcard.com/report/mattn/go-colorable)
Colorable writer for windows.
For example, most of logger packages doesn't show colors on windows. (I know we can do it with ansicon. But I don't want.)
This package is possible to handle escape sequence for ansi color on windows.
## Too Bad!
![](https://raw.githubusercontent.com/mattn/go-colorable/gh-pages/bad.png)
## So Good!
![](https://raw.githubusercontent.com/mattn/go-colorable/gh-pages/good.png)
## Usage
```go
logrus.SetFormatter(&logrus.TextFormatter{ForceColors: true})
logrus.SetOutput(colorable.NewColorableStdout())
logrus.Info("succeeded")
logrus.Warn("not correct")
logrus.Error("something error")
logrus.Fatal("panic")
```
You can compile above code on non-windows OSs.
## Installation
```
$ go get github.com/mattn/go-colorable
```
# License
MIT
# Author
Yasuhiro Matsumoto (a.k.a mattn)

View File

@@ -1,38 +0,0 @@
//go:build appengine
// +build appengine
package colorable
import (
"io"
"os"
_ "github.com/mattn/go-isatty"
)
// NewColorable returns new instance of Writer which handles escape sequence.
func NewColorable(file *os.File) io.Writer {
if file == nil {
panic("nil passed instead of *os.File to NewColorable()")
}
return file
}
// NewColorableStdout returns new instance of Writer which handles escape sequence for stdout.
func NewColorableStdout() io.Writer {
return os.Stdout
}
// NewColorableStderr returns new instance of Writer which handles escape sequence for stderr.
func NewColorableStderr() io.Writer {
return os.Stderr
}
// EnableColorsStdout enable colors if possible.
func EnableColorsStdout(enabled *bool) func() {
if enabled != nil {
*enabled = true
}
return func() {}
}

View File

@@ -1,38 +0,0 @@
//go:build !windows && !appengine
// +build !windows,!appengine
package colorable
import (
"io"
"os"
_ "github.com/mattn/go-isatty"
)
// NewColorable returns new instance of Writer which handles escape sequence.
func NewColorable(file *os.File) io.Writer {
if file == nil {
panic("nil passed instead of *os.File to NewColorable()")
}
return file
}
// NewColorableStdout returns new instance of Writer which handles escape sequence for stdout.
func NewColorableStdout() io.Writer {
return os.Stdout
}
// NewColorableStderr returns new instance of Writer which handles escape sequence for stderr.
func NewColorableStderr() io.Writer {
return os.Stderr
}
// EnableColorsStdout enable colors if possible.
func EnableColorsStdout(enabled *bool) func() {
if enabled != nil {
*enabled = true
}
return func() {}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +0,0 @@
#!/usr/bin/env bash
set -e
echo "" > coverage.txt
for d in $(go list ./... | grep -v vendor); do
go test -race -coverprofile=profile.out -covermode=atomic "$d"
if [ -f profile.out ]; then
cat profile.out >> coverage.txt
rm profile.out
fi
done

View File

@@ -1,57 +0,0 @@
package colorable
import (
"bytes"
"io"
)
// NonColorable holds writer but removes escape sequence.
type NonColorable struct {
out io.Writer
}
// NewNonColorable returns new instance of Writer which removes escape sequence from Writer.
func NewNonColorable(w io.Writer) io.Writer {
return &NonColorable{out: w}
}
// Write writes data on console
func (w *NonColorable) Write(data []byte) (n int, err error) {
er := bytes.NewReader(data)
var plaintext bytes.Buffer
loop:
for {
c1, err := er.ReadByte()
if err != nil {
plaintext.WriteTo(w.out)
break loop
}
if c1 != 0x1b {
plaintext.WriteByte(c1)
continue
}
_, err = plaintext.WriteTo(w.out)
if err != nil {
break loop
}
c2, err := er.ReadByte()
if err != nil {
break loop
}
if c2 != 0x5b {
continue
}
for {
c, err := er.ReadByte()
if err != nil {
break loop
}
if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' {
break
}
}
}
return len(data), nil
}

View File

@@ -1,9 +0,0 @@
Copyright (c) Yasuhiro MATSUMOTO <mattn.jp@gmail.com>
MIT License (Expat)
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -1,50 +0,0 @@
# go-isatty
[![Godoc Reference](https://godoc.org/github.com/mattn/go-isatty?status.svg)](http://godoc.org/github.com/mattn/go-isatty)
[![Codecov](https://codecov.io/gh/mattn/go-isatty/branch/master/graph/badge.svg)](https://codecov.io/gh/mattn/go-isatty)
[![Coverage Status](https://coveralls.io/repos/github/mattn/go-isatty/badge.svg?branch=master)](https://coveralls.io/github/mattn/go-isatty?branch=master)
[![Go Report Card](https://goreportcard.com/badge/mattn/go-isatty)](https://goreportcard.com/report/mattn/go-isatty)
isatty for golang
## Usage
```go
package main
import (
"fmt"
"github.com/mattn/go-isatty"
"os"
)
func main() {
if isatty.IsTerminal(os.Stdout.Fd()) {
fmt.Println("Is Terminal")
} else if isatty.IsCygwinTerminal(os.Stdout.Fd()) {
fmt.Println("Is Cygwin/MSYS2 Terminal")
} else {
fmt.Println("Is Not Terminal")
}
}
```
## Installation
```
$ go get github.com/mattn/go-isatty
```
## License
MIT
## Author
Yasuhiro Matsumoto (a.k.a mattn)
## Thanks
* k-takata: base idea for IsCygwinTerminal
https://github.com/k-takata/go-iscygpty

View File

@@ -1,2 +0,0 @@
// Package isatty implements interface to isatty
package isatty

View File

@@ -1,12 +0,0 @@
#!/usr/bin/env bash
set -e
echo "" > coverage.txt
for d in $(go list ./... | grep -v vendor); do
go test -race -coverprofile=profile.out -covermode=atomic "$d"
if [ -f profile.out ]; then
cat profile.out >> coverage.txt
rm profile.out
fi
done

View File

@@ -1,19 +0,0 @@
//go:build (darwin || freebsd || openbsd || netbsd || dragonfly) && !appengine
// +build darwin freebsd openbsd netbsd dragonfly
// +build !appengine
package isatty
import "golang.org/x/sys/unix"
// IsTerminal return true if the file descriptor is terminal.
func IsTerminal(fd uintptr) bool {
_, err := unix.IoctlGetTermios(int(fd), unix.TIOCGETA)
return err == nil
}
// IsCygwinTerminal return true if the file descriptor is a cygwin or msys2
// terminal. This is also always false on this environment.
func IsCygwinTerminal(fd uintptr) bool {
return false
}

View File

@@ -1,16 +0,0 @@
//go:build appengine || js || nacl || wasm
// +build appengine js nacl wasm
package isatty
// IsTerminal returns true if the file descriptor is terminal which
// is always false on js and appengine classic which is a sandboxed PaaS.
func IsTerminal(fd uintptr) bool {
return false
}
// IsCygwinTerminal() return true if the file descriptor is a cygwin or msys2
// terminal. This is also always false on this environment.
func IsCygwinTerminal(fd uintptr) bool {
return false
}

View File

@@ -1,23 +0,0 @@
//go:build plan9
// +build plan9
package isatty
import (
"syscall"
)
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd uintptr) bool {
path, err := syscall.Fd2path(int(fd))
if err != nil {
return false
}
return path == "/dev/cons" || path == "/mnt/term/dev/cons"
}
// IsCygwinTerminal return true if the file descriptor is a cygwin or msys2
// terminal. This is also always false on this environment.
func IsCygwinTerminal(fd uintptr) bool {
return false
}

View File

@@ -1,21 +0,0 @@
//go:build solaris && !appengine
// +build solaris,!appengine
package isatty
import (
"golang.org/x/sys/unix"
)
// IsTerminal returns true if the given file descriptor is a terminal.
// see: https://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libc/port/gen/isatty.c
func IsTerminal(fd uintptr) bool {
_, err := unix.IoctlGetTermio(int(fd), unix.TCGETA)
return err == nil
}
// IsCygwinTerminal return true if the file descriptor is a cygwin or msys2
// terminal. This is also always false on this environment.
func IsCygwinTerminal(fd uintptr) bool {
return false
}

View File

@@ -1,19 +0,0 @@
//go:build (linux || aix || zos) && !appengine
// +build linux aix zos
// +build !appengine
package isatty
import "golang.org/x/sys/unix"
// IsTerminal return true if the file descriptor is terminal.
func IsTerminal(fd uintptr) bool {
_, err := unix.IoctlGetTermios(int(fd), unix.TCGETS)
return err == nil
}
// IsCygwinTerminal return true if the file descriptor is a cygwin or msys2
// terminal. This is also always false on this environment.
func IsCygwinTerminal(fd uintptr) bool {
return false
}

View File

@@ -1,125 +0,0 @@
//go:build windows && !appengine
// +build windows,!appengine
package isatty
import (
"errors"
"strings"
"syscall"
"unicode/utf16"
"unsafe"
)
const (
objectNameInfo uintptr = 1
fileNameInfo = 2
fileTypePipe = 3
)
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
ntdll = syscall.NewLazyDLL("ntdll.dll")
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procGetFileInformationByHandleEx = kernel32.NewProc("GetFileInformationByHandleEx")
procGetFileType = kernel32.NewProc("GetFileType")
procNtQueryObject = ntdll.NewProc("NtQueryObject")
)
func init() {
// Check if GetFileInformationByHandleEx is available.
if procGetFileInformationByHandleEx.Find() != nil {
procGetFileInformationByHandleEx = nil
}
}
// IsTerminal return true if the file descriptor is terminal.
func IsTerminal(fd uintptr) bool {
var st uint32
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, fd, uintptr(unsafe.Pointer(&st)), 0)
return r != 0 && e == 0
}
// Check pipe name is used for cygwin/msys2 pty.
// Cygwin/MSYS2 PTY has a name like:
// \{cygwin,msys}-XXXXXXXXXXXXXXXX-ptyN-{from,to}-master
func isCygwinPipeName(name string) bool {
token := strings.Split(name, "-")
if len(token) < 5 {
return false
}
if token[0] != `\msys` &&
token[0] != `\cygwin` &&
token[0] != `\Device\NamedPipe\msys` &&
token[0] != `\Device\NamedPipe\cygwin` {
return false
}
if token[1] == "" {
return false
}
if !strings.HasPrefix(token[2], "pty") {
return false
}
if token[3] != `from` && token[3] != `to` {
return false
}
if token[4] != "master" {
return false
}
return true
}
// getFileNameByHandle use the undocomented ntdll NtQueryObject to get file full name from file handler
// since GetFileInformationByHandleEx is not available under windows Vista and still some old fashion
// guys are using Windows XP, this is a workaround for those guys, it will also work on system from
// Windows vista to 10
// see https://stackoverflow.com/a/18792477 for details
func getFileNameByHandle(fd uintptr) (string, error) {
if procNtQueryObject == nil {
return "", errors.New("ntdll.dll: NtQueryObject not supported")
}
var buf [4 + syscall.MAX_PATH]uint16
var result int
r, _, e := syscall.Syscall6(procNtQueryObject.Addr(), 5,
fd, objectNameInfo, uintptr(unsafe.Pointer(&buf)), uintptr(2*len(buf)), uintptr(unsafe.Pointer(&result)), 0)
if r != 0 {
return "", e
}
return string(utf16.Decode(buf[4 : 4+buf[0]/2])), nil
}
// IsCygwinTerminal() return true if the file descriptor is a cygwin or msys2
// terminal.
func IsCygwinTerminal(fd uintptr) bool {
if procGetFileInformationByHandleEx == nil {
name, err := getFileNameByHandle(fd)
if err != nil {
return false
}
return isCygwinPipeName(name)
}
// Cygwin/msys's pty is a pipe.
ft, _, e := syscall.Syscall(procGetFileType.Addr(), 1, fd, 0, 0)
if ft != fileTypePipe || e != 0 {
return false
}
var buf [2 + syscall.MAX_PATH]uint16
r, _, e := syscall.Syscall6(procGetFileInformationByHandleEx.Addr(),
4, fd, fileNameInfo, uintptr(unsafe.Pointer(&buf)),
uintptr(len(buf)*2), 0, 0)
if r == 0 || e != 0 {
return false
}
l := *(*uint32)(unsafe.Pointer(&buf))
return isCygwinPipeName(string(utf16.Decode(buf[2 : 2+l/2])))
}

View File

@@ -1,25 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
tmp
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof

1
vendor/github.com/rs/zerolog/CNAME generated vendored
View File

@@ -1 +0,0 @@
zerolog.io

21
vendor/github.com/rs/zerolog/LICENSE generated vendored
View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2017 Olivier Poitrey
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,716 +0,0 @@
# Zero Allocation JSON Logger
[![godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/rs/zerolog) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/rs/zerolog/master/LICENSE) [![Build Status](https://travis-ci.org/rs/zerolog.svg?branch=master)](https://travis-ci.org/rs/zerolog) [![Coverage](http://gocover.io/_badge/github.com/rs/zerolog)](http://gocover.io/github.com/rs/zerolog)
The zerolog package provides a fast and simple logger dedicated to JSON output.
Zerolog's API is designed to provide both a great developer experience and stunning [performance](#benchmarks). Its unique chaining API allows zerolog to write JSON (or CBOR) log events by avoiding allocations and reflection.
Uber's [zap](https://godoc.org/go.uber.org/zap) library pioneered this approach. Zerolog is taking this concept to the next level with a simpler to use API and even better performance.
To keep the code base and the API simple, zerolog focuses on efficient structured logging only. Pretty logging on the console is made possible using the provided (but inefficient) [`zerolog.ConsoleWriter`](#pretty-logging).
![Pretty Logging Image](pretty.png)
## Who uses zerolog
Find out [who uses zerolog](https://github.com/rs/zerolog/wiki/Who-uses-zerolog) and add your company / project to the list.
## Features
* [Blazing fast](#benchmarks)
* [Low to zero allocation](#benchmarks)
* [Leveled logging](#leveled-logging)
* [Sampling](#log-sampling)
* [Hooks](#hooks)
* [Contextual fields](#contextual-logging)
* `context.Context` integration
* [Integration with `net/http`](#integration-with-nethttp)
* [JSON and CBOR encoding formats](#binary-encoding)
* [Pretty logging for development](#pretty-logging)
* [Error Logging (with optional Stacktrace)](#error-logging)
## Installation
```bash
go get -u github.com/rs/zerolog/log
```
## Getting Started
### Simple Logging Example
For simple logging, import the global logger package **github.com/rs/zerolog/log**
```go
package main
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
// UNIX Time is faster and smaller than most timestamps
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Print("hello world")
}
// Output: {"time":1516134303,"level":"debug","message":"hello world"}
```
> Note: By default log writes to `os.Stderr`
> Note: The default log level for `log.Print` is *debug*
### Contextual Logging
**zerolog** allows data to be added to log messages in the form of key:value pairs. The data added to the message adds "context" about the log event that can be critical for debugging as well as myriad other purposes. An example of this is below:
```go
package main
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Debug().
Str("Scale", "833 cents").
Float64("Interval", 833.09).
Msg("Fibonacci is everywhere")
log.Debug().
Str("Name", "Tom").
Send()
}
// Output: {"level":"debug","Scale":"833 cents","Interval":833.09,"time":1562212768,"message":"Fibonacci is everywhere"}
// Output: {"level":"debug","Name":"Tom","time":1562212768}
```
> You'll note in the above example that when adding contextual fields, the fields are strongly typed. You can find the full list of supported fields [here](#standard-types)
### Leveled Logging
#### Simple Leveled Logging Example
```go
package main
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Info().Msg("hello world")
}
// Output: {"time":1516134303,"level":"info","message":"hello world"}
```
> It is very important to note that when using the **zerolog** chaining API, as shown above (`log.Info().Msg("hello world"`), the chain must have either the `Msg` or `Msgf` method call. If you forget to add either of these, the log will not occur and there is no compile time error to alert you of this.
**zerolog** allows for logging at the following levels (from highest to lowest):
* panic (`zerolog.PanicLevel`, 5)
* fatal (`zerolog.FatalLevel`, 4)
* error (`zerolog.ErrorLevel`, 3)
* warn (`zerolog.WarnLevel`, 2)
* info (`zerolog.InfoLevel`, 1)
* debug (`zerolog.DebugLevel`, 0)
* trace (`zerolog.TraceLevel`, -1)
You can set the Global logging level to any of these options using the `SetGlobalLevel` function in the zerolog package, passing in one of the given constants above, e.g. `zerolog.InfoLevel` would be the "info" level. Whichever level is chosen, all logs with a level greater than or equal to that level will be written. To turn off logging entirely, pass the `zerolog.Disabled` constant.
#### Setting Global Log Level
This example uses command-line flags to demonstrate various outputs depending on the chosen log level.
```go
package main
import (
"flag"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
debug := flag.Bool("debug", false, "sets log level to debug")
flag.Parse()
// Default level for this example is info, unless debug flag is present
zerolog.SetGlobalLevel(zerolog.InfoLevel)
if *debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
log.Debug().Msg("This message appears only when log level set to Debug")
log.Info().Msg("This message appears when log level set to Debug or Info")
if e := log.Debug(); e.Enabled() {
// Compute log output only if enabled.
value := "bar"
e.Str("foo", value).Msg("some debug message")
}
}
```
Info Output (no flag)
```bash
$ ./logLevelExample
{"time":1516387492,"level":"info","message":"This message appears when log level set to Debug or Info"}
```
Debug Output (debug flag set)
```bash
$ ./logLevelExample -debug
{"time":1516387573,"level":"debug","message":"This message appears only when log level set to Debug"}
{"time":1516387573,"level":"info","message":"This message appears when log level set to Debug or Info"}
{"time":1516387573,"level":"debug","foo":"bar","message":"some debug message"}
```
#### Logging without Level or Message
You may choose to log without a specific level by using the `Log` method. You may also write without a message by setting an empty string in the `msg string` parameter of the `Msg` method. Both are demonstrated in the example below.
```go
package main
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Log().
Str("foo", "bar").
Msg("")
}
// Output: {"time":1494567715,"foo":"bar"}
```
### Error Logging
You can log errors using the `Err` method
```go
package main
import (
"errors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
err := errors.New("seems we have an error here")
log.Error().Err(err).Msg("")
}
// Output: {"level":"error","error":"seems we have an error here","time":1609085256}
```
> The default field name for errors is `error`, you can change this by setting `zerolog.ErrorFieldName` to meet your needs.
#### Error Logging with Stacktrace
Using `github.com/pkg/errors`, you can add a formatted stacktrace to your errors.
```go
package main
import (
"github.com/pkg/errors"
"github.com/rs/zerolog/pkgerrors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
err := outer()
log.Error().Stack().Err(err).Msg("")
}
func inner() error {
return errors.New("seems we have an error here")
}
func middle() error {
err := inner()
if err != nil {
return err
}
return nil
}
func outer() error {
err := middle()
if err != nil {
return err
}
return nil
}
// Output: {"level":"error","stack":[{"func":"inner","line":"20","source":"errors.go"},{"func":"middle","line":"24","source":"errors.go"},{"func":"outer","line":"32","source":"errors.go"},{"func":"main","line":"15","source":"errors.go"},{"func":"main","line":"204","source":"proc.go"},{"func":"goexit","line":"1374","source":"asm_amd64.s"}],"error":"seems we have an error here","time":1609086683}
```
> zerolog.ErrorStackMarshaler must be set in order for the stack to output anything.
#### Logging Fatal Messages
```go
package main
import (
"errors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func main() {
err := errors.New("A repo man spends his life getting into tense situations")
service := "myservice"
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Fatal().
Err(err).
Str("service", service).
Msgf("Cannot start %s", service)
}
// Output: {"time":1516133263,"level":"fatal","error":"A repo man spends his life getting into tense situations","service":"myservice","message":"Cannot start myservice"}
// exit status 1
```
> NOTE: Using `Msgf` generates one allocation even when the logger is disabled.
### Create logger instance to manage different outputs
```go
logger := zerolog.New(os.Stderr).With().Timestamp().Logger()
logger.Info().Str("foo", "bar").Msg("hello world")
// Output: {"level":"info","time":1494567715,"message":"hello world","foo":"bar"}
```
### Sub-loggers let you chain loggers with additional context
```go
sublogger := log.With().
Str("component", "foo").
Logger()
sublogger.Info().Msg("hello world")
// Output: {"level":"info","time":1494567715,"message":"hello world","component":"foo"}
```
### Pretty logging
To log a human-friendly, colorized output, use `zerolog.ConsoleWriter`:
```go
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
log.Info().Str("foo", "bar").Msg("Hello world")
// Output: 3:04PM INF Hello World foo=bar
```
To customize the configuration and formatting:
```go
output := zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}
output.FormatLevel = func(i interface{}) string {
return strings.ToUpper(fmt.Sprintf("| %-6s|", i))
}
output.FormatMessage = func(i interface{}) string {
return fmt.Sprintf("***%s****", i)
}
output.FormatFieldName = func(i interface{}) string {
return fmt.Sprintf("%s:", i)
}
output.FormatFieldValue = func(i interface{}) string {
return strings.ToUpper(fmt.Sprintf("%s", i))
}
log := zerolog.New(output).With().Timestamp().Logger()
log.Info().Str("foo", "bar").Msg("Hello World")
// Output: 2006-01-02T15:04:05Z07:00 | INFO | ***Hello World**** foo:BAR
```
### Sub dictionary
```go
log.Info().
Str("foo", "bar").
Dict("dict", zerolog.Dict().
Str("bar", "baz").
Int("n", 1),
).Msg("hello world")
// Output: {"level":"info","time":1494567715,"foo":"bar","dict":{"bar":"baz","n":1},"message":"hello world"}
```
### Customize automatic field names
```go
zerolog.TimestampFieldName = "t"
zerolog.LevelFieldName = "l"
zerolog.MessageFieldName = "m"
log.Info().Msg("hello world")
// Output: {"l":"info","t":1494567715,"m":"hello world"}
```
### Add contextual fields to the global logger
```go
log.Logger = log.With().Str("foo", "bar").Logger()
```
### Add file and line number to log
Equivalent of `Llongfile`:
```go
log.Logger = log.With().Caller().Logger()
log.Info().Msg("hello world")
// Output: {"level": "info", "message": "hello world", "caller": "/go/src/your_project/some_file:21"}
```
Equivalent of `Lshortfile`:
```go
zerolog.CallerMarshalFunc = func(file string, line int) string {
short := file
for i := len(file) - 1; i > 0; i-- {
if file[i] == '/' {
short = file[i+1:]
break
}
}
file = short
return file + ":" + strconv.Itoa(line)
}
log.Logger = log.With().Caller().Logger()
log.Info().Msg("hello world")
// Output: {"level": "info", "message": "hello world", "caller": "some_file:21"}
```
### Thread-safe, lock-free, non-blocking writer
If your writer might be slow or not thread-safe and you need your log producers to never get slowed down by a slow writer, you can use a `diode.Writer` as follows:
```go
wr := diode.NewWriter(os.Stdout, 1000, 10*time.Millisecond, func(missed int) {
fmt.Printf("Logger Dropped %d messages", missed)
})
log := zerolog.New(wr)
log.Print("test")
```
You will need to install `code.cloudfoundry.org/go-diodes` to use this feature.
### Log Sampling
```go
sampled := log.Sample(&zerolog.BasicSampler{N: 10})
sampled.Info().Msg("will be logged every 10 messages")
// Output: {"time":1494567715,"level":"info","message":"will be logged every 10 messages"}
```
More advanced sampling:
```go
// Will let 5 debug messages per period of 1 second.
// Over 5 debug message, 1 every 100 debug messages are logged.
// Other levels are not sampled.
sampled := log.Sample(zerolog.LevelSampler{
DebugSampler: &zerolog.BurstSampler{
Burst: 5,
Period: 1*time.Second,
NextSampler: &zerolog.BasicSampler{N: 100},
},
})
sampled.Debug().Msg("hello world")
// Output: {"time":1494567715,"level":"debug","message":"hello world"}
```
### Hooks
```go
type SeverityHook struct{}
func (h SeverityHook) Run(e *zerolog.Event, level zerolog.Level, msg string) {
if level != zerolog.NoLevel {
e.Str("severity", level.String())
}
}
hooked := log.Hook(SeverityHook{})
hooked.Warn().Msg("")
// Output: {"level":"warn","severity":"warn"}
```
### Pass a sub-logger by context
```go
ctx := log.With().Str("component", "module").Logger().WithContext(ctx)
log.Ctx(ctx).Info().Msg("hello world")
// Output: {"component":"module","level":"info","message":"hello world"}
```
### Set as standard logger output
```go
log := zerolog.New(os.Stdout).With().
Str("foo", "bar").
Logger()
stdlog.SetFlags(0)
stdlog.SetOutput(log)
stdlog.Print("hello world")
// Output: {"foo":"bar","message":"hello world"}
```
### Integration with `net/http`
The `github.com/rs/zerolog/hlog` package provides some helpers to integrate zerolog with `http.Handler`.
In this example we use [alice](https://github.com/justinas/alice) to install logger for better readability.
```go
log := zerolog.New(os.Stdout).With().
Timestamp().
Str("role", "my-service").
Str("host", host).
Logger()
c := alice.New()
// Install the logger handler with default output on the console
c = c.Append(hlog.NewHandler(log))
// Install some provided extra handler to set some request's context fields.
// Thanks to that handler, all our logs will come with some prepopulated fields.
c = c.Append(hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
hlog.FromRequest(r).Info().
Str("method", r.Method).
Stringer("url", r.URL).
Int("status", status).
Int("size", size).
Dur("duration", duration).
Msg("")
}))
c = c.Append(hlog.RemoteAddrHandler("ip"))
c = c.Append(hlog.UserAgentHandler("user_agent"))
c = c.Append(hlog.RefererHandler("referer"))
c = c.Append(hlog.RequestIDHandler("req_id", "Request-Id"))
// Here is your final handler
h := c.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the logger from the request's context. You can safely assume it
// will be always there: if the handler is removed, hlog.FromRequest
// will return a no-op logger.
hlog.FromRequest(r).Info().
Str("user", "current user").
Str("status", "ok").
Msg("Something happened")
// Output: {"level":"info","time":"2001-02-03T04:05:06Z","role":"my-service","host":"local-hostname","req_id":"b4g0l5t6tfid6dtrapu0","user":"current user","status":"ok","message":"Something happened"}
}))
http.Handle("/", h)
if err := http.ListenAndServe(":8080", nil); err != nil {
log.Fatal().Err(err).Msg("Startup failed")
}
```
## Multiple Log Output
`zerolog.MultiLevelWriter` may be used to send the log message to multiple outputs.
In this example, we send the log message to both `os.Stdout` and the in-built ConsoleWriter.
```go
func main() {
consoleWriter := zerolog.ConsoleWriter{Out: os.Stdout}
multi := zerolog.MultiLevelWriter(consoleWriter, os.Stdout)
logger := zerolog.New(multi).With().Timestamp().Logger()
logger.Info().Msg("Hello World!")
}
// Output (Line 1: Console; Line 2: Stdout)
// 12:36PM INF Hello World!
// {"level":"info","time":"2019-11-07T12:36:38+03:00","message":"Hello World!"}
```
## Global Settings
Some settings can be changed and will be applied to all loggers:
* `log.Logger`: You can set this value to customize the global logger (the one used by package level methods).
* `zerolog.SetGlobalLevel`: Can raise the minimum level of all loggers. Call this with `zerolog.Disabled` to disable logging altogether (quiet mode).
* `zerolog.DisableSampling`: If argument is `true`, all sampled loggers will stop sampling and issue 100% of their log events.
* `zerolog.TimestampFieldName`: Can be set to customize `Timestamp` field name.
* `zerolog.LevelFieldName`: Can be set to customize level field name.
* `zerolog.MessageFieldName`: Can be set to customize message field name.
* `zerolog.ErrorFieldName`: Can be set to customize `Err` field name.
* `zerolog.TimeFieldFormat`: Can be set to customize `Time` field value formatting. If set with `zerolog.TimeFormatUnix`, `zerolog.TimeFormatUnixMs` or `zerolog.TimeFormatUnixMicro`, times are formated as UNIX timestamp.
* `zerolog.DurationFieldUnit`: Can be set to customize the unit for time.Duration type fields added by `Dur` (default: `time.Millisecond`).
* `zerolog.DurationFieldInteger`: If set to `true`, `Dur` fields are formatted as integers instead of floats (default: `false`).
* `zerolog.ErrorHandler`: Called whenever zerolog fails to write an event on its output. If not set, an error is printed on the stderr. This handler must be thread safe and non-blocking.
## Field Types
### Standard Types
* `Str`
* `Bool`
* `Int`, `Int8`, `Int16`, `Int32`, `Int64`
* `Uint`, `Uint8`, `Uint16`, `Uint32`, `Uint64`
* `Float32`, `Float64`
### Advanced Fields
* `Err`: Takes an `error` and renders it as a string using the `zerolog.ErrorFieldName` field name.
* `Func`: Run a `func` only if the level is enabled.
* `Timestamp`: Inserts a timestamp field with `zerolog.TimestampFieldName` field name, formatted using `zerolog.TimeFieldFormat`.
* `Time`: Adds a field with time formatted with `zerolog.TimeFieldFormat`.
* `Dur`: Adds a field with `time.Duration`.
* `Dict`: Adds a sub-key/value as a field of the event.
* `RawJSON`: Adds a field with an already encoded JSON (`[]byte`)
* `Hex`: Adds a field with value formatted as a hexadecimal string (`[]byte`)
* `Interface`: Uses reflection to marshal the type.
Most fields are also available in the slice format (`Strs` for `[]string`, `Errs` for `[]error` etc.)
## Binary Encoding
In addition to the default JSON encoding, `zerolog` can produce binary logs using [CBOR](https://cbor.io) encoding. The choice of encoding can be decided at compile time using the build tag `binary_log` as follows:
```bash
go build -tags binary_log .
```
To Decode binary encoded log files you can use any CBOR decoder. One has been tested to work
with zerolog library is [CSD](https://github.com/toravir/csd/).
## Related Projects
* [grpc-zerolog](https://github.com/cheapRoc/grpc-zerolog): Implementation of `grpclog.LoggerV2` interface using `zerolog`
* [overlog](https://github.com/Trendyol/overlog): Implementation of `Mapped Diagnostic Context` interface using `zerolog`
* [zerologr](https://github.com/go-logr/zerologr): Implementation of `logr.LogSink` interface using `zerolog`
## Benchmarks
See [logbench](http://hackemist.com/logbench/) for more comprehensive and up-to-date benchmarks.
All operations are allocation free (those numbers *include* JSON encoding):
```text
BenchmarkLogEmpty-8 100000000 19.1 ns/op 0 B/op 0 allocs/op
BenchmarkDisabled-8 500000000 4.07 ns/op 0 B/op 0 allocs/op
BenchmarkInfo-8 30000000 42.5 ns/op 0 B/op 0 allocs/op
BenchmarkContextFields-8 30000000 44.9 ns/op 0 B/op 0 allocs/op
BenchmarkLogFields-8 10000000 184 ns/op 0 B/op 0 allocs/op
```
There are a few Go logging benchmarks and comparisons that include zerolog.
* [imkira/go-loggers-bench](https://github.com/imkira/go-loggers-bench)
* [uber-common/zap](https://github.com/uber-go/zap#performance)
Using Uber's zap comparison benchmark:
Log a message and 10 fields:
| Library | Time | Bytes Allocated | Objects Allocated |
| :--- | :---: | :---: | :---: |
| zerolog | 767 ns/op | 552 B/op | 6 allocs/op |
| :zap: zap | 848 ns/op | 704 B/op | 2 allocs/op |
| :zap: zap (sugared) | 1363 ns/op | 1610 B/op | 20 allocs/op |
| go-kit | 3614 ns/op | 2895 B/op | 66 allocs/op |
| lion | 5392 ns/op | 5807 B/op | 63 allocs/op |
| logrus | 5661 ns/op | 6092 B/op | 78 allocs/op |
| apex/log | 15332 ns/op | 3832 B/op | 65 allocs/op |
| log15 | 20657 ns/op | 5632 B/op | 93 allocs/op |
Log a message with a logger that already has 10 fields of context:
| Library | Time | Bytes Allocated | Objects Allocated |
| :--- | :---: | :---: | :---: |
| zerolog | 52 ns/op | 0 B/op | 0 allocs/op |
| :zap: zap | 283 ns/op | 0 B/op | 0 allocs/op |
| :zap: zap (sugared) | 337 ns/op | 80 B/op | 2 allocs/op |
| lion | 2702 ns/op | 4074 B/op | 38 allocs/op |
| go-kit | 3378 ns/op | 3046 B/op | 52 allocs/op |
| logrus | 4309 ns/op | 4564 B/op | 63 allocs/op |
| apex/log | 13456 ns/op | 2898 B/op | 51 allocs/op |
| log15 | 14179 ns/op | 2642 B/op | 44 allocs/op |
Log a static string, without any context or `printf`-style templating:
| Library | Time | Bytes Allocated | Objects Allocated |
| :--- | :---: | :---: | :---: |
| zerolog | 50 ns/op | 0 B/op | 0 allocs/op |
| :zap: zap | 236 ns/op | 0 B/op | 0 allocs/op |
| standard library | 453 ns/op | 80 B/op | 2 allocs/op |
| :zap: zap (sugared) | 337 ns/op | 80 B/op | 2 allocs/op |
| go-kit | 508 ns/op | 656 B/op | 13 allocs/op |
| lion | 771 ns/op | 1224 B/op | 10 allocs/op |
| logrus | 1244 ns/op | 1505 B/op | 27 allocs/op |
| apex/log | 2751 ns/op | 584 B/op | 11 allocs/op |
| log15 | 5181 ns/op | 1592 B/op | 26 allocs/op |
## Caveats
Note that zerolog does no de-duplication of fields. Using the same key multiple times creates multiple keys in final JSON:
```go
logger := zerolog.New(os.Stderr).With().Timestamp().Logger()
logger.Info().
Timestamp().
Msg("dup")
// Output: {"level":"info","time":1494567715,"time":1494567715,"message":"dup"}
```
In this case, many consumers will take the last value, but this is not guaranteed; check yours if in doubt.

View File

@@ -1 +0,0 @@
remote_theme: rs/gh-readme

Some files were not shown because too many files have changed in this diff Show More