Compare commits

...

102 Commits

Author SHA1 Message Date
mochi-co
0648e39507 Update Readme 2023-06-19 10:28:22 +01:00
mochi-co
233a82e448 Add Healthcheck listener 2023-06-19 10:22:55 +01:00
JB
51a8d8cb54 Update README.md 2023-06-19 10:15:37 +01:00
mochi-co
23c3208310 Update SPDX annotations 2023-06-19 10:14:07 +01:00
mochi-co
23e1092cda Update Contribution Guidelines 2023-06-19 10:13:43 +01:00
mochi-co
d498576927 Update server version 2023-06-19 09:51:41 +01:00
Derek Duncan
7e14ce99b5 Add healthcheck listener (#244)
* Add healthcheck listener

* Update improper comments

---------

Co-authored-by: Derek Duncan <derekduncan@gmail.com>
Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-06-19 09:44:03 +01:00
thedevop
4db49a4b9d Fix ScanSubscribersTopicInheritanceBug (#243)
* Sub a/b should not receive msg for a/b/c...

* Add TestScanSubscribersTopicInheritanceBug test

* Ensure sharedSubscription are gathered

* Fix Unsubscribe for sharedSub and optimization

* Unsub with lower case in TestUnsubscribeShared

* Add test with # for TestScanSubscribersShared
2023-06-19 09:42:16 +01:00
mochi-co
e60b8ff0c9 Update Hooks List 2023-06-14 20:12:52 +01:00
mochi-co
b9d5dcb5f0 Update Server Version 2023-06-14 20:12:31 +01:00
JB
6d394d1fe9 Expose SendConnack, err return on OnConnect (#240) 2023-06-14 20:05:02 +01:00
thedevop
1ee2158711 Add OnRetainPublished hook (#237)
* Add OnRetainPublished hook

* Skip OnRetainPublished if publish error
2023-06-13 19:24:04 +01:00
mochi-co
af79b55b9f Update server version 2023-06-04 07:32:34 +01:00
Derek Duncan
e1a9497c25 Add retainMessage to LWT to properly handle message retention (#234)
* Add retainMessage to LWT to properly handle message retention if specified in connect

* Add will retain flag on missed test

---------

Co-authored-by: Derek Duncan <derekduncan@gmail.com>
2023-06-04 07:31:55 +01:00
mochi-co
62659e17ba Update server version 2023-05-18 20:29:56 +01:00
Hector Oliveros
7ad6dd8e1a Now when a "publish" command fails, then the publish method will throw an error (#229)
Errors in the hook when doing a publish were ignored. This caused that test cases could not be made where the publish failed and an error was thrown.

Co-authored-by: hector.oliveros@wabtec.com <hectoroliveros@MacBook-Pro-de-Hector.local>
2023-05-18 20:14:50 +01:00
thedevop
565e07747e Minimize client lock duration (#223)
* Minimize client lock duration
* Fix server option example
2023-05-18 20:01:35 +01:00
plourdedominic
6acd775a6b Fix example usage of NewHTTPStats (#231)
Co-authored-by: Dominic Plourde <plourded@amotus.ca>
2023-05-18 16:14:29 +01:00
JB
493f6c8bb0 Update README.md new benchmarks 2023-05-15 20:01:12 +01:00
mochi-co
d3785c2717 update server version 2023-05-08 11:43:46 +01:00
thedevop
52a347169a Use context to exit WriteLoop (#222)
* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Use context to exit WriteLoop

* Fix misspelling
2023-05-08 11:30:44 +01:00
mochi-co
797d75cb34 update server version 2023-05-06 14:32:42 +01:00
JB
5225a357e5 refactor server keepalive for hook access (#220) 2023-05-06 14:11:54 +01:00
JB
a734a0dc73 Use context to signal client open state (#218) 2023-05-06 11:55:40 +01:00
JB
6704cf7227 Add packet ID exhausted hook (#217) 2023-05-06 10:37:27 +01:00
thedevop
9233e6fd39 Expire session if SessionExpiryInterval is 0 (#216)
If SessionExpiryInterval was not set in CONNECT, SessionExpiryIntervalFlag is also not set. According to spec:
  If the Session Expiry Interval is absent the value 0 is used. If it is set to 0, or is absent, the Session ends when the Network Connection is closed.
2023-05-06 10:12:33 +01:00
ħþ
1ca65d9631 Update codes.go (#215)
fix typo
2023-05-06 10:02:25 +01:00
ħþ
33229da885 Update codes.go (#214)
Fix typo
2023-05-06 09:59:50 +01:00
mochi-co
c274d5fd08 Update server version 2023-05-05 00:04:27 +01:00
JB
10e82f41d6 Lock on close outbound (#213) 2023-05-05 00:02:49 +01:00
JB
e6c07b2b78 Add lock to client writes (#212) 2023-05-04 23:17:12 +01:00
JB
eed3ef9606 Add OnPacketIDExhausted hook (#211) 2023-05-04 22:51:40 +01:00
JB
1ec880844d Correctly validate WillProperties (#210)
Co-authored-by: sukvojte <sukvojte@gmail.com>
2023-05-04 22:37:23 +01:00
werbenhu
4b49652a8c Update build.yml (#203) 2023-05-04 18:09:58 +01:00
mochi-co
d46e7b5bcf Protect close of nil outbound channel 2023-04-21 22:09:59 +01:00
mochi-co
17fb7dadbc Protect close of nil outbound channel 2023-04-21 22:00:27 +01:00
werben
ed7fd836e1 #78 storage hook should not execute the relevant code if the client has been reconnected (#198)
* storage hook should not execute the relevant code if the client has been reconnected #78

* add test cases for coverage decrease

add test cases for coverage decrease
2023-04-21 21:52:44 +01:00
mochi-co
605bb93c75 Move msgToPacket to storage.Message.ToPacket 2023-04-21 21:49:49 +01:00
Wind
c73ace2ea0 Simplified code (#195)
Simplify the code for the loadInflight and loadRetained methods.
Adjust the validation order of the processSubscribe method to ensure that it fails quickly if there is an error, since s.hooks. OnACLCheck generally takes a long time.

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-04-21 21:29:03 +01:00
thedevop
aac6d699da Ensure to close client WriteLoop (#193)
* Ensure client WriteLoop is closed

* Ensure to close client WriteLoop
2023-04-21 21:20:46 +01:00
Hubertus Hohl
7bd7bd5087 fix: common subscriptions issued by different clients at the same time may be lost (#186) 2023-03-11 23:17:10 +00:00
mochi-co
655bf9fdb1 Update readme 2023-03-11 23:15:51 +00:00
mochi-co
b188055c7d Update server version 2023-03-11 23:14:51 +00:00
JB
aaf1d9d4c6 Configurable client bufio reader/writer sizes (#190) 2023-03-11 23:13:28 +00:00
mochi-co
44ce819318 Update server version 2023-02-28 20:58:05 +00:00
dependabot[bot]
e4c76cc60c Bump golang.org/x/net from 0.0.0-20220927171203-f486391704dc to 0.7.0 (#182)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.0.0-20220927171203-f486391704dc to 0.7.0.
- [Release notes](https://github.com/golang/net/releases)
- [Commits](https://github.com/golang/net/commits/v0.7.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-02-28 20:57:33 +00:00
thedevop
da79faa972 Skip expire cleanup for isTakenOver session (#183)
* Skip expire cleanup for isTakenOver session

* Set prev connection to isTakenOver on CleanSession

#173.
2023-02-28 20:53:26 +00:00
JB
46babc89c8 Allow 0 byte usernames if correctly formed (#181)
* Allow 0 byte usernames if correctly formed

* Allow 0 byte usernames if correctly formed
2023-02-25 01:37:54 +00:00
JB
9b7a943888 Correctly identify and clean taken-over sessions (#180) 2023-02-25 01:24:17 +00:00
mochi-co
a909d30923 Small style fix 2023-02-22 23:33:49 +00:00
mochi-co
0851b09e4d Update server version 2023-02-22 23:06:22 +00:00
thedevop
a302c9dd88 Use *packets.Packet for outbound chan (#176)
* Use *packets.Packet for outbound chan

* Use *packets.Packet for outbound chan

* Use *packets.Packet for outbound chan
2023-02-22 23:02:44 +00:00
mochi-co
1e8f922102 update server version 2023-02-20 18:14:57 +00:00
Hubertus Hohl
4c16e5593f fix: correct decoding of packets including Properties exceeding 127 bytes in length (#172) 2023-02-20 18:14:19 +00:00
mochi-co
49cada4fbc Update server version 2023-02-10 23:39:27 +00:00
JB
ef34510c0b Expose dropped publish messages count in sys info (#170) 2023-02-10 23:38:20 +00:00
JB
e5716caad1 Fix potential NextPacketID endless loop, expand tests (#169)
* Fix possible NextPacketID endless loop, expand tests

* Optimize NextPacketID

* Use math constants
2023-02-10 23:27:21 +00:00
thedevop
4b039cb35c Add PublishDropped metrics (#167)
* Add PublishDropped

* Add PublishDropped

* Add PublishDropped

* Update storage_test.go

* Update system.go

* Update server.go
2023-02-10 14:44:01 +00:00
JB
aac245441a No longer issue retained messages on session takeover (#166) 2023-02-09 23:57:24 +00:00
JB
bb54cc68e6 Client write buffers (#165)
* Replace fanpool with client write buffers
2023-02-09 22:34:30 +00:00
thedevop
7ba1352a60 Add Clone to system.Info (#163)
* Add Clone using atomic operations

* Add Clone using atomic operations

* Use sysinfo.Clone

* Unit test for Clone

* Add Clone using atomic operations

* Update

* Update
2023-02-09 19:07:17 +00:00
mochi-co
ca849131eb Update server version 2023-02-05 11:07:07 +00:00
Wind
ba7e534122 failed to delete inflight data (#162)
The s.hooks.OnQosPublish method needs to be called, otherwise the following s.hooks.OnQosComplete or processPuback(s.hooks.OnQosComplete) method will report a data not found error.
2023-02-05 10:53:49 +00:00
mochi-co
db760c34a5 Update server version 2023-02-04 10:57:27 +00:00
JB
ae3ee81bb4 Rename Quota methods for clarity (#159) 2023-02-04 10:53:45 +00:00
JB
c2ca02d149 Move refreshDeadline to only trigger on successful transmission (#157) 2023-02-04 10:16:05 +00:00
Jeroen Rinzema
77a64d9c87 Include a listener accepting an existing net.Listener (#155) 2023-02-04 10:10:10 +00:00
Wind
8dec9cc962 invalid config type provided (#152)
* invalid config type provided

examples/persistence/bolt/main.go: invalid config type provided

* fixed ErrReceiveMaximum(receive maximum exceeded)

No quotas of the inflight is set in the readStore method, so each quota is equal to 0. The inheritClientSession method overrides the quotas of the new client inflight, so the processPublish method reports an ErrReceiveMaximum and disconnects the client.

* reset receive quota

receive quota should be reset across connections (as specified in the spec).
2023-02-04 10:06:26 +00:00
mochi-co
f90e52328d Update server version 2023-01-16 20:08:55 +00:00
JB
50aae47618 Publish retained messages only after connack (#147) 2023-01-16 19:50:01 +00:00
JB
0d79f2d63b Use Atomic instead of RWMutex for Hooks concurrency (#148)
* Use Atomic instead of RWMutex for Hooks concurrency
* Lock Hooks on Add Hook
2023-01-16 19:49:36 +00:00
JB
300152413c Ignore retain as published v3 (#142)
* Optimise Capabilities struct alignment

* Only use RetainAsPublished for v5 clients
2023-01-13 23:38:49 +00:00
mochi-co
0de1d731db Update version number 2023-01-10 00:01:21 +00:00
JB
80746abc52 Use correct connack return codes for MQTTv3 (#140) 2023-01-10 00:00:43 +00:00
mochi-co
a73cf4ca0e Update server version 2023-01-09 23:08:49 +00:00
mochi-co
bc549ee7ed Fix example imports 2023-01-09 22:52:24 +00:00
mochi-co
c464b46713 export client.Net.Conn for external use 2023-01-09 22:49:40 +00:00
mochi-co
05ce56008c Small code improvements 2023-01-09 22:49:20 +00:00
JB
8254cb0cbc Make hooks safe for concurrency (#139)
Co-authored-by: thedevop <60499013+thedevop@users.noreply.github.com>
2023-01-09 22:41:44 +00:00
mochi-co
4ae58b79e3 Update server version 2023-01-07 20:13:48 +00:00
thedevop
b895d688e0 Change inline check order (#133) 2023-01-07 20:02:05 +00:00
mochi-co
a600cd4ead fix grammar on Closed method doc 2023-01-07 17:57:04 +00:00
mochi-co
cdb44990cf Update version number 2023-01-07 17:30:58 +00:00
mochi-co
2d9c128111 Refactor stored subscription value assignments 2023-01-07 17:30:30 +00:00
mochi-co
a0d5bdb39f Fix Typos 2023-01-07 17:30:01 +00:00
Wind
4ebcef3cb6 Save subscription properties for mqttv5 (#131)
* Update redis.go

Save the subscription properties for mqqtv5

* Update badger.go

Save the subscription properties for mqqtv5

* Update bolt.go

Save the subscription properties for mqqtv5

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>
2023-01-07 17:24:23 +00:00
thedevop
fb8d4720d7 Add Client Closed (#130)
* Add Client Closed
* Add Client Closed
* Update clients_test.go
2023-01-07 17:14:51 +00:00
JB
4080c89127 Update README.md 2022-12-21 21:00:53 +00:00
JB
1b67e6f3f6 Update README.md 2022-12-21 20:58:25 +00:00
mochi-co
1adb02e087 Update readme and server version 2022-12-21 20:47:58 +00:00
JB
4d4140aa99 Connect ReturnResponseInfo only applies to Connack values (#128) 2022-12-21 20:37:08 +00:00
JB
e31840a37d Optimize inflight expiry (#127)
* Small formatting/style changes for filter existed

* Use OnQoSDropped hook instead of onInflightExpired
2022-12-21 19:44:25 +00:00
JB
7d2e16f2d3 Merge pull request #123 from wind-c/master
Variable existed in the method processSubscribe is unstable
2022-12-21 11:41:14 +00:00
JB
92cd935a16 Merge branch 'master' into master 2022-12-21 11:38:28 +00:00
JB
25ce27ce2d Merge pull request #124 from zgwit/master
Add unix socket listener
2022-12-21 11:28:23 +00:00
jason
527d084a4b Add unix socket listener 2022-12-20 23:02:59 +08:00
Wind
bb9f937bb0 Variable existed in the method processSubscribe is unstable
The variable existed can be changed repeatedly within a for loop. An array variable must be used to record the subscription of each filter.
2022-12-18 13:46:06 +08:00
Wind
511fe88684 Merge branch 'mochi-co:master' into master 2022-12-17 12:33:09 +08:00
JB
75504ff201 Update server version 2022-12-16 18:27:29 +00:00
Wind
a556feb325 Add the OnUnsubscribed hook to the unsubscribeClient method (#122)
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-16 18:23:58 +00:00
“Wind”
d06f47f4b9 Add the OnUnsubscribed hook to the unsubscribeClient method
Add the OnUnsubscribed hook to the unsubscribeClient method,and change the unsubscribeClient to externally visible. In a clustered environment, if a client is disconnected and then connected to another node, the subscriptions on the previous node need to be cleared.
2022-12-17 00:40:06 +08:00
JB
8d4cc091b4 Update version number 2022-12-16 00:31:59 +00:00
240 changed files with 24085 additions and 3597 deletions

View File

@@ -7,37 +7,39 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: 1.19
- name: Vet
run: go vet ./...
- name: Test
run: go test -race ./... && echo true
coverage:
name: Test with Coverage
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.19.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.19'
- name: Check out code
uses: actions/checkout@v3
- name: Install dependencies
run: |
go mod download
- name: Run Unit tests
run: |
go test -race -covermode atomic -coverprofile=covprofile ./...
- name: Install goveralls
run: go install github.com/mattn/goveralls@latest
- name: Send coverage
env:
COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: goveralls -coverprofile=covprofile -service=github

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
cmd/mqtt
.DS_Store
*.db
.idea

165
README.md
View File

@@ -2,7 +2,7 @@
<p align="center">
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master&v2)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
@@ -16,7 +16,7 @@ Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/m
### What is MQTT?
MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol.
## What's new in Version 2.0.0?
## What's new in Version 2?
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
Don't forget to use the new v2 import paths:
@@ -37,14 +37,14 @@ import "github.com/mochi-co/mqtt/v2"
- Plus all the original MQTT features of Mochi MQTT v1, such as Full QoS(0,1,2), $SYS topics, retained messages, etc.
- Developer-centric:
- Most core broker code is now exported and accessible, for total developer control.
- Full featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
- Full-featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
- Direct Packet Injection using special inline client, or masquerade as existing clients.
- Performant and Stable:
- Our classic trie-based Topic-Subscription model.
- A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
- Over a thousand carefully considered unit test scenarios.
- TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners.
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
@@ -83,22 +83,26 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
Importing Mochi MQTT as a package requires just a few lines of code to get started.
``` go
import (
"log"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
// Create the new MQTT Server.
server := mqtt.New(nil)
// Allow all connections.
_ = server.AddHook(new(auth.AllowHook), nil)
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.Serve()
if err != nil {
@@ -112,10 +116,16 @@ Examples of running the broker with various configurations can be found in the [
#### Network Listeners
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
- `listeners.NewTCP(...)` - A TCP listener.
- `listeners.NewWebsocket(...)` A Websocket listener.
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
| Listener | Usage |
|------------------------------|----------------------------------------------------------------------------------------------|
| listeners.NewTCP | A TCP listener |
| listeners.NewUnixSock | A Unix Socket listener |
| listeners.NewNet | A net.Listener listener |
| listeners.NewWebsocket | A Websocket listener |
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
| listeners.NewHTTPHealthCheck | An HTTP healthcheck listener to provide health check responses for e.g. cloud infrastructure |
> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
A `*listeners.Config` may be passed to configure TLS.
@@ -132,11 +142,13 @@ server := mqtt.New(&mqtt.Options{
ObscureNotAuthorized: true,
},
},
ClientNetWriteBufferSize: 4096,
ClientNetReadBufferSize: 4096,
SysTopicResendInterval: 10,
})
```
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options.
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
## Event Hooks
@@ -258,50 +270,50 @@ For more information on how the badger hook works, or how to use it, see the [ex
There is also a BoltDB hook which has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go).
## Developing with Event Hooks
Many hooks are available for interacting with the broker and client lifecycle.
The function signatures for all the hooks and `mqtt.Hook` interface can be found in [hooks.go](hooks.go).
> The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks be used to control and modify all incoming and outgoing packets.
| Function | Usage |
| -------------------------- | -- |
| OnStarted | Called when the server has successfully started.|
| OnStopped | Called when the server has successfully stopped. |
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
| OnSysInfoTick | Called when the $SYS topic values are published out. |
| OnConnect | Called when a new client connects |
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
| OnPacketSent | Called when a packet has been sent to a client. |
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.|
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnRetainMessage | Called then a published message is retained. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
| Function | Usage |
|------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| OnStarted | Called when the server has successfully started. |
| OnStopped | Called when the server has successfully stopped. |
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
| OnSysInfoTick | Called when the $SYS topic values are published out. |
| OnConnect | Called when a new client connects, may return an error or packet code to halt the client connection process. |
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
| OnPacketSent | Called when a packet has been sent to a client. |
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification. |
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
| OnRetainMessage | Called then a published message is retained. |
| OnRetainPublished | Called then a retained message is published to a client. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnPacketIDExhausted | Called when a client runs out of unused packet ids to assign. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
@@ -334,6 +346,8 @@ server.InjectPacket(cl, packets.Packet{
See the [hooks example](examples/hooks/main.go) to see this feature in action.
### Testing
#### Unit Tests
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
@@ -353,39 +367,54 @@ Mochi MQTT performance is comparable with popular brokers such as Mosquitto, EMQ
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a Apple Macbook Air M2, using `cmd/main.go` default settings. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better.
> The values presented in the benchmark are not representative of true messages per second throughput. They rely on an unusual calculation by mqtt-stresser, but are usable as they are consistent across all brokers.
> Benchmarks are provided as a general performance expectation guideline only.
> Benchmarks are provided as a general performance expectation guideline only. Comparisons are performed using out-of-the-box default configurations.
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 |
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 |
| Mochi v2.2.10 | 124,772 | 125,456 | 124,614 | 314,461 | 313,186 | 311,910 |
| [Mosquitto v2.0.15](https://github.com/eclipse/mosquitto) | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| [EMQX v5.0.11](https://github.com/emqx/emqx) | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
| [Rumqtt v0.21.0](https://github.com/bytebeamio/rumqtt) | 112,208 | 108,480 | 104,753 | 135,784 | 126,446 | 117,108 |
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 |
| Mochi v2.2.10 | 41,825 | 31,663| 23,008 | 144,058 | 65,903 | 37,618 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
| Rumqtt v0.21.0 | 42,213 | 23,153 | 20,814 | 49,465 | 36,626 | 19,283 |
Million Message Challenge (hit the server with 1 million messages immediately):
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.0.0 | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636 |
| Mochi v2.2.10 | 13,532 | 4,425 | 2,344 | 52,120 | 7,274 | 2,701 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
| Rumqtt v0.21.0 | 78,972 | 5,047 | 3,804 | 4,286 | 3,249 | 2,027 |
> Not sure what's going on with EMQX here, perhaps the docker out-of-the-box settings are not optimal, so take it with a pinch of salt as we know for a fact it's a solid piece of software.
## Contribution Guidelines
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request. If you open a pull request, please try to follow the following guidelines:
- Try to maintain test coverage where reasonably possible.
- Clearly state what the PR does and why.
- Remember to add your SPDX FileContributor tag to files where you have made a meaningful contribution.
[SPDX Annotations](https://spdx.dev) are used to clearly indicate the license, copyright, and contributions of each file in a machine-readable format. If you are adding a new file to the repository, please ensure it has the following SPDX header:
```go
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: Your name or alias <optional@email.address>
package name
```
Please ensure to add a new `SPDX-FileContributor` line for each contributor to the file. Refer to other files for examples. Please remember to do this, your contributions to this project are valuable and appreciated - it's important to receive credit!
## Stargazers over time 🥰
[![Stargazers over time](https://starchart.cc/mochi-co/mqtt.svg)](https://starchart.cc/mochi-co/mqtt)
Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-co/mqtt/issues)
## Contributions
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request.

View File

@@ -7,6 +7,7 @@ package mqtt
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
@@ -87,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
if client.Net.Listener == id && !client.Closed() {
clients = append(clients, client)
}
}
@@ -106,7 +107,7 @@ type Client struct {
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
conn net.Conn // the net.Conn used to establish the connection
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
@@ -135,26 +136,35 @@ type Will struct {
// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
endOnce sync.Once // only end once
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
keepalive uint16 // the number of seconds the connection can wait
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
cancelOpen context.CancelFunc // cancel function for open context
outboundQty int32 // number of messages currently in the outbound queue
Keepalive uint16 // the number of seconds the connection can wait
ServerKeepalive bool // keepalive was set by the server
}
// newClient returns a new instance of Client. This is almost exclusively used by Server
// for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
ctx, cancel := context.WithCancel(context.Background())
cl := &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
open: ctx,
cancelOpen: cancel,
Keepalive: defaultKeepalive,
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
@@ -164,17 +174,33 @@ func newClient(c net.Conn, o *ops) *Client {
if c != nil {
cl.Net = ClientConnection{
conn: c,
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
Conn: c,
bconn: bufio.NewReadWriter(
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize),
),
Remote: c.RemoteAddr().String(),
}
}
cl.refreshDeadline(cl.State.keepalive)
return cl
}
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for {
select {
case pk := <-cl.State.outbound:
if err := cl.WritePacket(*pk); err != nil {
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
}
atomic.AddInt32(&cl.State.outboundQty, -1)
case <-cl.State.open.Done():
return
}
}
}
// ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid
@@ -184,9 +210,9 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.Keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
cl.ID = pk.Connect.ClientIdentifier
@@ -195,11 +221,6 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Props.AssignedClientID = cl.ID
}
cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}
if pk.Connect.WillFlag {
cl.Properties.Will = Will{
Qos: pk.Connect.WillQos,
@@ -217,19 +238,17 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Properties.Will.Flag = 1 // atomic for checking
}
}
cl.refreshDeadline(cl.State.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
if cl.Net.conn != nil {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
if cl.Net.Conn != nil {
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
@@ -237,28 +256,30 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i + 1
started := i
overflowed := false
for {
if i >= 65535 {
overflowed = true
i = 1
} else {
i++
}
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
if i >= cl.ops.options.Capabilities.maximumPacketID {
overflowed = true
i = 0
continue
}
i++
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
break
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
}
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
@@ -272,7 +293,7 @@ func (cl *Client) ResendInflightMessages(force bool) error {
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
}
// cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends)
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
err := cl.WritePacket(tk)
if err != nil {
return err
@@ -290,17 +311,18 @@ func (cl *Client) ResendInflightMessages(force bool) error {
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
var deleted int64
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted++
deleted = append(deleted, tk.PacketID)
}
}
}
return deleted
}
@@ -310,11 +332,11 @@ func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return nil
}
cl.refreshDeadline(cl.State.keepalive)
cl.refreshDeadline(cl.State.Keepalive)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
@@ -335,20 +357,20 @@ func (cl *Client) Read(packetHandler ReadFn) error {
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.done) == 1 {
return
}
cl.State.endOnce.Do(func() {
if cl.Net.conn != nil {
_ = cl.Net.conn.Close() // omit close error
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
if err != nil {
cl.State.stopCause.Store(err)
}
atomic.StoreUint32(&cl.State.done, 1)
if cl.State.cancelOpen != nil {
cl.State.cancelOpen()
}
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
@@ -361,6 +383,11 @@ func (cl *Client) StopCause() error {
return cl.State.stopCause.Load().(error)
}
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return cl.State.open == nil || cl.State.open.Err() != nil
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
@@ -383,7 +410,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
return err
}
if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize {
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
@@ -454,15 +481,14 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) error {
if atomic.LoadUint32(&cl.State.done) == 1 {
if cl.Closed() {
return ErrConnectionClosed
}
if cl.Net.conn == nil {
if cl.Net.Conn == nil {
return nil
}
defer cl.refreshDeadline(cl.State.keepalive)
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
@@ -476,8 +502,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
@@ -527,7 +553,11 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.conn)
n, err := func() (int64, error) {
cl.Lock()
defer cl.Unlock()
return nb.WriteTo(cl.Net.Conn)
}()
if err != nil {
return err
}

View File

@@ -5,6 +5,7 @@
package mqtt
import (
"context"
"errors"
"io"
"net"
@@ -29,9 +30,13 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
options: &Options{
Capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
},
})
@@ -42,6 +47,9 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
go cl.WriteLoop()
return
}
@@ -107,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
@@ -125,10 +133,12 @@ func TestNewClient(t *testing.T) {
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.conn)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.NotNil(t, cl.ops.options.Capabilities)
require.False(t, cl.Net.Inline)
}
@@ -155,7 +165,7 @@ func TestClientParseConnect(t *testing.T) {
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
@@ -163,8 +173,8 @@ func TestClientParseConnect(t *testing.T) {
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
}
@@ -237,28 +247,32 @@ func TestClientNextPacketIDInUse(t *testing.T) {
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newTestClient()
for i := 0; i <= 65535; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
}
i, err := cl.NextPacketID()
require.Equal(t, uint32(0), i)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
require.Equal(t, uint32(0), i)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = uint32(65534)
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(65535), i)
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
_, err = cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
}
func TestClientClearInflights(t *testing.T) {
@@ -272,7 +286,9 @@ func TestClientClearInflights(t *testing.T) {
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
cl.ClearInflights(n, 4)
deleted := cl.ClearInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
require.Equal(t, 2, cl.State.Inflight.Len())
}
@@ -318,7 +334,7 @@ func TestClientResendInflightMessagesNoMessages(t *testing.T) {
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newTestClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
@@ -352,7 +368,7 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.capabilities.MaximumPacketSize = 2
cl.ops.options.Capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
@@ -451,7 +467,7 @@ func TestClientReadOK(t *testing.T) {
func TestClientReadDone(t *testing.T) {
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
cl.State.cancelOpen()
o := make(chan error)
go func() {
@@ -468,10 +484,17 @@ func TestClientStop(t *testing.T) {
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.True(t, cl.Closed())
require.Equal(t, nil, cl.StopCause())
}
func TestClientClosed(t *testing.T) {
cl, _, _ := newTestClient()
require.False(t, cl.Closed())
cl.Stop(nil)
require.True(t, cl.Closed())
}
func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
@@ -577,7 +600,7 @@ func TestClientReadPacket(t *testing.T) {
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.conn.Close()
cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
@@ -601,7 +624,7 @@ func TestClientWritePacket(t *testing.T) {
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.conn.Close()
cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
@@ -683,7 +706,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.conn.Close()
cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)

View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@@ -30,14 +30,14 @@ func main() {
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(debug.Hook), &debug.Options{
// ShowPacketData: true,
})
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(debug.Hook), &debug.Options{
ShowPacketData: true,
})
err = server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}

View File

@@ -110,8 +110,9 @@ func (h *ExampleHook) Init(config any) error {
return nil
}
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
return nil
}
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {

View File

@@ -26,10 +26,8 @@ func main() {
}()
server := mqtt.New(nil)
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
_ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
@@ -62,6 +60,7 @@ func (h *pahoAuthHook) ID() string {
func (h *pahoAuthHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnConnect,
mqtt.OnACLCheck,
}, []byte{b})
}
@@ -73,3 +72,12 @@ func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet)
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return topic != "test/nosubscribe"
}
func (h *pahoAuthHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
// Handle paho test_server_keep_alive
if pk.Connect.Keepalive == 120 && pk.Connect.Clean {
cl.State.Keepalive = 60
cl.State.ServerKeepalive = true
}
return nil
}

View File

@@ -30,12 +30,15 @@ func main() {
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(bolt.Hook), bolt.Options{
err := server.AddHook(new(bolt.Hook), &bolt.Options{
Path: "bolt.db",
Options: &bbolt.Options{
Timeout: 500 * time.Millisecond,
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)

View File

@@ -97,7 +97,7 @@ func main() {
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{
TLSConfig: tlsConfig,
}, nil)
}, server.Info)
err = server.AddListener(stats)
if err != nil {
log.Fatal(err)

View File

@@ -1,101 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
package mqtt
import (
"sync"
"sync/atomic"
xh "github.com/cespare/xxhash/v2"
)
// taskChan is a channel for incoming task functions.
type taskChan chan func()
// FanPool is a fixed-sized fan-style worker pool with multiple
// working 'columns'. Instead of a single queue channel processed by
// many goroutines, this fan pool uses many queue channels each
// processed by a single goroutine.
// Very special thanks are given to the authors of HMQ in particular
// @chowyu08 and @muXxer for their work on the fixpool worker pool
// https://github.com/fhmq/hmq/blob/master/pool/fixpool.go
// from which this fan-pool is heavily inspired.
type FanPool struct {
queue []taskChan
wg sync.WaitGroup
capacity uint64
perChan uint64
Mutex sync.Mutex
}
// New returns a new instance of FanPool. fanSize controls the number of 'columns'
// of the fan, whereas queueSize controls the size of each column's queue.
func NewFanPool(fanSize, queueSize uint64) *FanPool {
pool := &FanPool{
capacity: fanSize,
perChan: queueSize,
queue: make([]taskChan, fanSize),
}
pool.fillWorkers(fanSize)
return pool
}
// fillWorkers adds columns to the fan pool with an associated worker goroutine.
func (p *FanPool) fillWorkers(n uint64) {
for i := uint64(0); i < n; i++ {
p.queue[i] = make(taskChan, p.perChan)
go p.worker(p.queue[i])
p.wg.Add(1)
}
}
// worker is a worker goroutine which processes tasks from a single queue.
func (p *FanPool) worker(ch taskChan) {
defer p.wg.Done()
var task func()
var ok bool
for {
task, ok = <-ch
if !ok {
return
}
task()
}
}
// Enqueue adds a new task to the queue to be processed.
func (p *FanPool) Enqueue(id string, task func()) {
if p.Size() == 0 {
return
}
// We can use xh.Sum64 to get a specific queue index
// which remains the same for a client id, giving each
// client their own queue.
p.queue[xh.Sum64([]byte(id))%p.Size()] <- task
}
// Wait blocks until all the workers in the pool have completed.
func (p *FanPool) Wait() {
p.wg.Wait()
}
// Close issues a shutdown signal to the workers.
func (p *FanPool) Close() {
for i := 0; i < int(p.Size()); i++ {
if p.queue[i] != nil {
close(p.queue[i])
}
}
p.queue = nil
atomic.StoreUint64(&p.capacity, 0)
}
// Size returns the current number of workers in the pool.
func (p *FanPool) Size() uint64 {
return atomic.LoadUint64(&p.capacity)
}

View File

@@ -1,89 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestFanPool(t *testing.T) {
f := NewFanPool(1, 2)
require.NotNil(t, f)
require.Equal(t, uint64(1), f.capacity)
require.Equal(t, 2, cap(f.queue[0]))
o := make(chan bool)
go func() {
f.Enqueue("test", func() {
o <- true
})
}()
require.True(t, <-o)
f.Close()
f.Wait()
}
func TestFillWorkers(t *testing.T) {
f := &FanPool{
perChan: 3,
queue: make([]taskChan, 2),
}
f.fillWorkers(2)
require.Len(t, f.queue, 2)
require.Equal(t, 3, cap(f.queue[0]))
}
func TestEnqueue(t *testing.T) {
f := &FanPool{
capacity: 2,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
},
}
go func() {
f.Enqueue("a", func() {})
}()
require.NotNil(t, <-f.queue[1])
}
func TestEnqueueOnEmpty(t *testing.T) {
f := &FanPool{
queue: []taskChan{},
}
go func() {
f.Enqueue("a", func() {})
}()
require.Len(t, f.queue, 0)
}
func TestSize(t *testing.T) {
f := &FanPool{
capacity: 10,
}
require.Equal(t, uint64(10), f.Size())
}
func TestClose(t *testing.T) {
f := &FanPool{
capacity: 3,
queue: []taskChan{
make(taskChan, 2),
make(taskChan, 2),
make(taskChan, 2),
},
}
f.Close()
require.Equal(t, uint64(0), f.Size())
require.Nil(t, f.queue)
}

6
go.mod
View File

@@ -6,7 +6,6 @@ require (
github.com/alicebob/miniredis/v2 v2.23.0
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/cespare/xxhash/v2 v2.1.2
github.com/go-redis/redis/v8 v8.11.5
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
@@ -21,6 +20,7 @@ require (
require (
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/badger v1.6.0 // indirect
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
@@ -33,8 +33,8 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
golang.org/x/net v0.0.0-20220927171203-f486391704dc // indirect
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sys v0.5.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.28.1 // indirect
)

10
go.sum
View File

@@ -109,8 +109,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20220927171203-f486391704dc h1:FxpXZdoBqT8RjqTy6i1E8nXHhW21wK7ptQ/EPIGxzPQ=
golang.org/x/net v0.0.0-20220927171203-f486391704dc/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -118,11 +118,11 @@ golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=

192
hooks.go
View File

@@ -1,6 +1,6 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// SPDX-FileContributor: mochi-co, thedevop
package mqtt
@@ -39,15 +39,17 @@ const (
OnUnsubscribed
OnPublish
OnPublished
OnPublishDropped
OnRetainMessage
OnRetainPublished
OnQosPublish
OnQosComplete
OnQosDropped
OnPacketIDExhausted
OnWill
OnWillSent
OnClientExpired
OnRetainedExpired
OnExpireInflights
StoredClients
StoredSubscriptions
StoredInflightMessages
@@ -73,7 +75,7 @@ type Hook interface {
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
OnACLCheck(cl *Client, topic string, write bool) bool
OnSysInfoTick(*system.Info)
OnConnect(cl *Client, pk packets.Packet)
OnConnect(cl *Client, pk packets.Packet) error
OnSessionEstablished(cl *Client, pk packets.Packet)
OnDisconnect(cl *Client, err error, expire bool)
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
@@ -88,15 +90,17 @@ type Hook interface {
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnPublishDropped(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnRetainPublished(cl *Client, pk packets.Packet)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
OnQosDropped(cl *Client, pk packets.Packet)
OnPacketIDExhausted(cl *Client, pk packets.Packet)
OnWill(cl *Client, will Will) (Will, error)
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
OnExpireInflights(cl *Client, expiry int64)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
@@ -112,10 +116,10 @@ type HookOptions struct {
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal []Hook // a slice of hooks
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
@@ -125,7 +129,7 @@ func (h *Hooks) Len() int64 {
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
for _, hb := range b {
if hook.Provides(hb) {
return true
@@ -140,26 +144,39 @@ func (h *Hooks) Provides(b ...byte) bool {
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
if h.internal == nil {
h.internal = []Hook{}
}
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
h.internal = append(h.internal, hook)
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
if err := hook.Stop(); err != nil {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
@@ -174,7 +191,7 @@ func (h *Hooks) Stop() {
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
@@ -183,7 +200,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) {
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
@@ -192,25 +209,29 @@ func (h *Hooks) OnStarted() {
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
}
}
// OnConnect is called when a new client connects.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
// OnConnect is called when a new client connects, and may return a packets.Code as an error to halt the connection.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) error {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnect) {
hook.OnConnect(cl, pk)
err := hook.OnConnect(cl, pk)
if err != nil {
return err
}
}
}
return nil
}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
@@ -219,7 +240,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
@@ -229,7 +250,7 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
@@ -250,7 +271,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
@@ -266,7 +287,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet,
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
@@ -277,7 +298,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
@@ -287,7 +308,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
@@ -299,7 +320,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
@@ -309,7 +330,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
@@ -321,7 +342,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
@@ -334,7 +355,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
@@ -344,7 +365,7 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
@@ -356,16 +377,17 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
if err != nil {
if errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
return pk, err
}
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet error")
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
@@ -375,27 +397,46 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
}
}
// OnPublishDropped is called when a message to a client was dropped instead of delivered
// such as when a client is too slow to respond.
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublishDropped) {
hook.OnPublishDropped(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
}
}
// OnRetainPublished is called when a retained message is published.
func (h *Hooks) OnRetainPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainPublished) {
hook.OnRetainPublished(cl, pk)
}
}
}
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
@@ -406,7 +447,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
@@ -414,22 +455,32 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned.
// It is typically used to delete an inflight message from a store.
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
}
}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to
// assign to a packet.
func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketIDExhausted) {
hook.OnPacketIDExhausted(cl, pk)
}
}
}
// OnWill is called when a client disconnects and publishes an LWT message. This method
// differs from OnWillSent in that it allows you to modify the LWT message before it is
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
@@ -445,7 +496,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
@@ -454,7 +505,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
@@ -463,7 +514,7 @@ func (h *Hooks) OnClientExpired(cl *Client) {
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
@@ -473,7 +524,7 @@ func (h *Hooks) OnRetainedExpired(filter string) {
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
@@ -493,7 +544,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
@@ -513,7 +564,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
@@ -533,7 +584,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
@@ -552,7 +603,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
@@ -574,7 +625,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
@@ -590,7 +641,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.internal {
for _, hook := range h.GetAll() {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
@@ -601,19 +652,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnExpireInflights is called when the server issues a clear request for expired
// inflight messages. Expiry should be the time after which the message is no longer
// valid (usually some time in the past). A message has expired if it's created time
// is older than time.Now() minus Inflight TTL. This method can be used to expire
// old inflight messages in a persistent store which doesnt support per-item TTL.
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
for _, hook := range h.internal {
if hook.Provides(OnExpireInflights) {
hook.OnExpireInflights(cl, expiry)
}
}
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
@@ -646,7 +684,7 @@ func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
h.Opts = opts
}
// Stop is called to gracefully shutdown the hook.
// Stop is called to gracefully shut down the hook.
func (h *HookBase) Stop() error {
return nil
}
@@ -671,7 +709,9 @@ func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
}
// OnConnect is called when a new client connects.
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {}
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) error {
return nil
}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
@@ -729,9 +769,15 @@ func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, err
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
// OnRetainPublished is called when a retained message is published.
func (h *HookBase) OnRetainPublished(cl *Client, pk packets.Packet) {}
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
@@ -741,6 +787,9 @@ func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
// OnQosDropped is called the Qos flow for a message expires.
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
// OnWill is called when a client disconnects and publishes an LWT message.
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
return will, nil
@@ -755,9 +804,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return

View File

@@ -80,7 +80,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -183,6 +182,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
@@ -199,11 +202,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Upsert(in.ID, in)
@@ -348,32 +355,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.Delete(m.ID, new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.Delete(retainedKey(filter), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
@@ -382,6 +370,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")

View File

@@ -5,13 +5,11 @@
package badger
import (
"errors"
"os"
"strings"
"testing"
"time"
"github.com/asdine/storm/v3"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
@@ -170,6 +168,21 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, badgerhold.ErrNotFound, err)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -219,6 +232,29 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -333,6 +369,21 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -419,48 +470,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -185,6 +184,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
@@ -201,12 +204,17 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
@@ -369,34 +377,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var v []storage.Message
err := h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
return
}
for _, m := range v {
if m.Created < expiry || m.Created == 0 {
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
return
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
}
@@ -404,6 +391,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")

View File

@@ -5,7 +5,6 @@
package bolt
import (
"errors"
"os"
"testing"
"time"
@@ -212,6 +211,21 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, storm.ErrNotFound, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -227,6 +241,29 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h.config.Path, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -341,6 +378,21 @@ func TestOnRetainedExpired(t *testing.T) {
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
@@ -427,48 +479,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var v []storage.Message
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
require.Len(t, v, 1)
require.Equal(t, "i1", v[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)

View File

@@ -83,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.OnExpireInflights,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
@@ -200,6 +199,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return
}
if cl.StopCause() == packets.ErrSessionTakenOver {
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
@@ -216,11 +219,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Filter: pk.Filters[i].Filter,
Qos: reasonCodes[i],
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
@@ -364,37 +371,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
}
}
// OnExpireInflights removes all inflight messages which have passed the
// provided expiry time.
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
if d.Created < expiry || d.Created == 0 {
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
@@ -403,6 +386,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")

View File

@@ -253,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
require.ErrorIs(t, redis.Nil, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -269,6 +285,28 @@ func TestOnDisconnectClosedDB(t *testing.T) {
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectSessionTakenOver(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
testClient := &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
testClient.Stop(packets.ErrSessionTakenOver)
teardown(t, h)
h.OnDisconnect(testClient, nil, true)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -392,6 +430,22 @@ func TestOnRetainedExpired(t *testing.T) {
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
@@ -484,60 +538,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
h.OnQosDropped(client, packets.Packet{})
}
func TestOnExpireInflights(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
n := time.Now().Unix()
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
&storage.Message{ID: "i3", T: storage.InflightKey},
).Err()
require.NoError(t, err)
h.OnExpireInflights(client, time.Now().Unix()-10)
var r []storage.Message
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
require.NoError(t, err)
require.Len(t, rows, 1)
for _, row := range rows {
var d storage.Message
err = d.UnmarshalBinary([]byte(row))
require.NoError(t, err)
r = append(r, d)
}
require.Len(t, r, 1)
require.Equal(t, "i1", r[0].ID)
}
func TestOnExpireInflightsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnExpireInflightsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnExpireInflights(client, time.Now().Unix()-10)
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()

View File

@@ -117,6 +117,36 @@ func (d *Message) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, d)
}
// ToPacket converts a storage.Message to a standard packet.
func (d *Message) ToPacket() packets.Packet {
pk := packets.Packet{
FixedHeader: d.FixedHeader,
PacketID: d.PacketID,
TopicName: d.TopicName,
Payload: d.Payload,
Origin: d.Origin,
Created: d.Created,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
}
// Return a deep copy of the packet data otherwise the slices will
// continue pointing at the values from the storage packet.
pk = pk.Copy(true)
pk.FixedHeader.Dup = d.FixedHeader.Dup
return pk
}
// Subscription is a storable representation of an mqtt subscription.
type Subscription struct {
T string `json:"t"`

View File

@@ -104,6 +104,7 @@ var (
ClientsMaximum: 7,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
PacketsReceived: 12,
PacketsSent: 13,
Retained: 15,
@@ -111,7 +112,7 @@ var (
InflightDropped: 17,
},
}
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)
func TestClientMarshalBinary(t *testing.T) {
@@ -193,3 +194,35 @@ func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
require.NoError(t, err)
require.Equal(t, SystemInfo{}, d)
}
func TestMessageToPacket(t *testing.T) {
d := messageStruct
pk := d.ToPacket()
require.Equal(t, packets.Packet{
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: d.FixedHeader.Remaining,
Type: d.FixedHeader.Type,
Qos: d.FixedHeader.Qos,
Dup: d.FixedHeader.Dup,
Retain: d.FixedHeader.Retain,
},
Origin: d.Origin,
TopicName: d.TopicName,
Properties: packets.Properties{
PayloadFormat: d.Properties.PayloadFormat,
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
ContentType: d.Properties.ContentType,
ResponseTopic: d.Properties.ResponseTopic,
CorrelationData: d.Properties.CorrelationData,
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
TopicAlias: d.Properties.TopicAlias,
User: d.Properties.User,
},
PacketID: 100,
Created: d.Created,
}, pk)
}

View File

@@ -27,6 +27,10 @@ type modifiedHookBase struct {
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
@@ -46,6 +50,14 @@ func (h *modifiedHookBase) Stop() error {
return nil
}
func (h *modifiedHookBase) OnConnect(cl *Client, pk packets.Packet) error {
if h.fail {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return true
}
@@ -178,12 +190,20 @@ func TestHooksProvides(t *testing.T) {
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddAndLen(t *testing.T) {
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {
@@ -216,7 +236,6 @@ func TestHooksNonReturns(t *testing.T) {
h.OnStarted()
h.OnStopped()
h.OnSysInfoTick(new(system.Info))
h.OnConnect(cl, packets.Packet{})
h.OnSessionEstablished(cl, packets.Packet{})
h.OnDisconnect(cl, nil, false)
h.OnPacketSent(cl, packets.Packet{}, []byte{})
@@ -224,14 +243,16 @@ func TestHooksNonReturns(t *testing.T) {
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnPublishDropped(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnRetainPublished(cl, packets.Packet{})
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})
h.OnQosDropped(cl, packets.Packet{})
h.OnPacketIDExhausted(cl, packets.Packet{})
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
h.OnExpireInflights(cl, time.Now().Unix()-1)
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)
@@ -325,7 +346,7 @@ func TestHooksOnPublish(t *testing.T) {
// coverage: failure
hook.fail = true
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Error(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
@@ -380,6 +401,22 @@ func TestHooksOnAuthPacket(t *testing.T) {
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnConnect(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
hook.fail = true
err = h.OnConnect(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
}
func TestHooksOnPacketEncode(t *testing.T) {
h := new(Hooks)
h.Log = &logger
@@ -552,12 +589,19 @@ func TestHookBaseOnConnectAuthenticate(t *testing.T) {
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, v)
}
func TestHookBaseOnACLCheck(t *testing.T) {
h := new(HookBase)
v := h.OnACLCheck(new(Client), "topic", true)
require.False(t, v)
}
func TestHookBaseOnConnect(t *testing.T) {
h := new(HookBase)
err := h.OnConnect(new(Client), packets.Packet{})
require.NoError(t, err)
}
func TestHookBaseOnPublish(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})

View File

@@ -58,6 +58,18 @@ func (i *Inflight) Len() int {
return len(i.internal)
}
// Clone returns a new instance of Inflight with the same message data.
// This is used when transferring inflights from a taken-over session.
func (i *Inflight) Clone() *Inflight {
c := NewInflights()
i.RLock()
defer i.RUnlock()
for k, v := range i.internal {
c.internal[k] = v
}
return c
}
// GetAll returns all the inflight messages.
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
i.RLock()
@@ -104,14 +116,14 @@ func (i *Inflight) Delete(id uint16) bool {
}
// TakeRecieveQuota reduces the receive quota by 1.
func (i *Inflight) TakeReceiveQuota() {
func (i *Inflight) DecreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) > 0 {
atomic.AddInt32(&i.receiveQuota, -1)
}
}
// TakeRecieveQuota increases the receive quota by 1.
func (i *Inflight) ReturnReceiveQuota() {
func (i *Inflight) IncreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
atomic.AddInt32(&i.receiveQuota, 1)
}
@@ -123,15 +135,15 @@ func (i *Inflight) ResetReceiveQuota(n int32) {
atomic.StoreInt32(&i.maximumReceiveQuota, n)
}
// TakeSendQuota reduces the send quota by 1.
func (i *Inflight) TakeSendQuota() {
// DecreaseSendQuota reduces the send quota by 1.
func (i *Inflight) DecreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) > 0 {
atomic.AddInt32(&i.sendQuota, -1)
}
}
// ReturnSendQuota increases the send quota by 1.
func (i *Inflight) ReturnSendQuota() {
// IncreaseSendQuota increases the send quota by 1.
func (i *Inflight) IncreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
atomic.AddInt32(&i.sendQuota, 1)
}

View File

@@ -61,6 +61,16 @@ func TestInflightLen(t *testing.T) {
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightClone(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
cloned := cl.State.Inflight.Clone()
require.NotNil(t, cloned)
require.NotSame(t, cloned, cl.State.Inflight)
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newTestClient()
@@ -95,12 +105,12 @@ func TestReceiveQuota(t *testing.T) {
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
// Return 1
i.ReturnReceiveQuota()
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Try to go over max limit
i.ReturnReceiveQuota()
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
@@ -110,12 +120,12 @@ func TestReceiveQuota(t *testing.T) {
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
// Take 1
i.TakeReceiveQuota()
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
// Try to go below zero
i.TakeReceiveQuota()
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}
@@ -137,12 +147,12 @@ func TestSendQuota(t *testing.T) {
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
// Return 1
i.ReturnSendQuota()
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Try to go over max limit
i.ReturnSendQuota()
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
@@ -152,12 +162,12 @@ func TestSendQuota(t *testing.T) {
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
// Take 1
i.TakeSendQuota()
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
// Try to go below zero
i.TakeSendQuota()
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}

View File

@@ -0,0 +1,104 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-co
// SPDX-FileContributor: Derek Duncan
package listeners
import (
"context"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog"
)
// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint.
type HTTPHealthCheck struct {
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // the http server
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address.
func NewHTTPHealthCheck(id, address string, config *Config) *HTTPHealthCheck {
if config == nil {
config = new(Config)
}
return &HTTPHealthCheck{
id: id,
address: address,
config: config,
}
}
// ID returns the id of the listener.
func (l *HTTPHealthCheck) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *HTTPHealthCheck) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *HTTPHealthCheck) Protocol() string {
if l.listen != nil && l.listen.TLSConfig != nil {
return "https"
}
return "http"
}
// Init initializes the listener.
func (l *HTTPHealthCheck) Init(log *zerolog.Logger) error {
l.log = log
mux := http.NewServeMux()
mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
}
})
l.listen = &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
Addr: l.address,
Handler: mux,
}
if l.config.TLSConfig != nil {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
// Serve starts listening for new connections and serving responses.
func (l *HTTPHealthCheck) Serve(establish EstablishFn) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *HTTPHealthCheck) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}

View File

@@ -0,0 +1,143 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-co
// SPDX-FileContributor: Derek Duncan
package listeners
import (
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewHTTPHealthCheck(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
require.Equal(t, "healthcheck", l.id)
require.Equal(t, testAddr, l.address)
}
func TestHTTPHealthCheckID(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
require.Equal(t, "healthcheck", l.ID())
}
func TestHTTPHealthCheckAddress(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
require.Equal(t, testAddr, l.Address())
}
func TestHTTPHealthCheckProtocol(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
require.Equal(t, "http", l.Protocol())
}
func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
l.Init(nil)
require.Equal(t, "https", l.Protocol())
}
func TestHTTPHealthCheckInit(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
err := l.Init(nil)
require.NoError(t, err)
require.NotNil(t, l.listen)
require.Equal(t, testAddr, l.listen.Addr)
}
func TestHTTPHealthCheckServeAndClose(t *testing.T) {
// setup http stats listener
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
// call healthcheck
resp, err := http.Get("http://localhost" + testAddr + "/healthcheck")
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// ensure listening is closed
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck")
require.Error(t, err)
<-o
}
func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
// setup http stats listener
l := NewHTTPHealthCheck("healthcheck", testAddr, nil)
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
// make disallowed method type http request
resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody)
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
require.NoError(t, err)
// ensure listening is closed
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody)
require.Error(t, err)
<-o
}
func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
l.Close(MockCloser)
}

View File

@@ -107,28 +107,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info := &system.Info{
Version: l.sysInfo.Version,
Started: atomic.LoadInt64(&l.sysInfo.Started),
Time: atomic.LoadInt64(&l.sysInfo.Time),
Uptime: atomic.LoadInt64(&l.sysInfo.Uptime),
BytesReceived: atomic.LoadInt64(&l.sysInfo.BytesReceived),
BytesSent: atomic.LoadInt64(&l.sysInfo.BytesSent),
ClientsConnected: atomic.LoadInt64(&l.sysInfo.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&l.sysInfo.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&l.sysInfo.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&l.sysInfo.MessagesReceived),
MessagesSent: atomic.LoadInt64(&l.sysInfo.MessagesSent),
InflightDropped: atomic.LoadInt64(&l.sysInfo.InflightDropped),
Subscriptions: atomic.LoadInt64(&l.sysInfo.Subscriptions),
PacketsReceived: atomic.LoadInt64(&l.sysInfo.PacketsReceived),
PacketsSent: atomic.LoadInt64(&l.sysInfo.PacketsSent),
Retained: atomic.LoadInt64(&l.sysInfo.Retained),
Inflight: atomic.LoadInt64(&l.sysInfo.Inflight),
MemoryAlloc: atomic.LoadInt64(&l.sysInfo.MemoryAlloc),
Threads: atomic.LoadInt64(&l.sysInfo.Threads),
}
info := *l.sysInfo.Clone()
out, err := json.MarshalIndent(info, "", "\t")
if err != nil {

92
listeners/net.go Normal file
View File

@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-co
// SPDX-FileContributor: Jeroen Rinzema
package listeners
import (
"net"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// Net is a listener for establishing client connections on basic TCP protocol.
type Net struct { // [MQTT-4.2.0-1]
mu sync.Mutex
listener net.Listener // a net.Listener which will listen for new clients
id string // the internal id of the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
func NewNet(id string, listener net.Listener) *Net {
return &Net{
id: id,
listener: listener,
}
}
// ID returns the id of the listener.
func (l *Net) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Net) Address() string {
return l.listener.Addr().String()
}
// Protocol returns the network of the listener.
func (l *Net) Protocol() string {
return l.listener.Addr().Network()
}
// Init initializes the listener.
func (l *Net) Init(log *zerolog.Logger) error {
l.log = log
return nil
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *Net) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listener.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *Net) Close(closeClients CloseFn) {
l.mu.Lock()
defer l.mu.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listener != nil {
err := l.listener.Close()
if err != nil {
return
}
}
}

105
listeners/net_test.go Normal file
View File

@@ -0,0 +1,105 @@
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewNet(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.id)
}
func TestNetID(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.ID())
}
func TestNetAddress(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, n.Addr().String(), l.Address())
}
func TestNetProtocol(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "tcp", l.Protocol())
}
func TestNetInit(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
}
func TestNetServeAndClose(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestNetEstablishThenEnd(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", n.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

98
listeners/unixsock.go Normal file
View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"net"
"os"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once.
}
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock {
return &UnixSock{
id: id,
address: address,
}
}
// ID returns the id of the listener.
func (l *UnixSock) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *UnixSock) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *UnixSock) Protocol() string {
return "unix"
}
// Init initializes the listener.
func (l *UnixSock) Init(log *zerolog.Logger) error {
l.log = log
var err error
_ = os.Remove(l.address)
l.listen, err = net.Listen("unix", l.address)
return err
}
// Serve starts waiting for new UnixSock connections, and calls the establish
// connection callback for any received.
func (l *UnixSock) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *UnixSock) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testUnixAddr = "mochi.sock"
func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address)
}
func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.ID())
}
func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, testUnixAddr, l.Address())
}
func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "unix", l.Protocol())
}
func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr)
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
}
func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("unix", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@@ -154,7 +154,7 @@ func (ws *wsConn) Read(p []byte) (int, error) {
br, err = r.Read(p[n:])
n += br
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
return n, err

View File

@@ -376,7 +376,7 @@ func TestEncodeUint16(t *testing.T) {
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result)
result = encodeUint16(65535)
result = encodeUint16(math.MaxUint16)
require.Equal(t, []byte{0xff, 0xff}, result)
}

View File

@@ -21,7 +21,7 @@ func (c Code) Error() string {
}
var (
// QosCodes indicicates the reason codes for each Qos byte.
// QosCodes indicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
@@ -113,15 +113,35 @@ var (
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
// MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
Err3ClientIdentifierNotValid = Code{Code: 0x02}
Err3ServerUnavailable = Code{Code: 0x03}
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
Err3NotAuthorized = Code{Code: 0x05}
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
// This is required because MQTTv3 has different return byte specification.
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
V5CodesToV3 = map[Code]Code{
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
ErrServerUnavailable: Err3ServerUnavailable,
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
ErrBadUsernameOrPassword: Err3NotAuthorized,
}
)

View File

@@ -8,6 +8,7 @@ import (
"bytes"
"errors"
"fmt"
"math"
"strconv"
"strings"
"sync"
@@ -15,22 +16,23 @@ import (
// All of the valid packet types and their packet identifier.
const (
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Auth // 15
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1
Connack // 2
Publish // 3
Puback // 4
Pubrec // 5
Pubrel // 6
Pubcomp // 7
Subscribe // 8
Suback // 9
Unsubscribe // 10
Unsuback // 11
Pingreq // 12
Pingresp // 13
Disconnect // 14
Auth // 15
WillProperties byte = 99 // Special byte for validating Will Properties.
)
var (
@@ -208,7 +210,10 @@ func (pk *Packet) Copy(allowTransfer bool) Packet {
Created: pk.Created,
Expiry: pk.Expiry,
Origin: pk.Origin,
PacketID: pk.PacketID, // ... ? Packet ID must not be transferred (in this manner)
}
if allowTransfer {
p.PacketID = pk.PacketID
}
if len(pk.Connect.ProtocolName) > 0 {
@@ -309,7 +314,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
(&pk.Properties).Encode(pk, pb, 0)
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -318,7 +323,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.Connect.WillFlag {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
(&pk.Connect).WillProperties.Encode(pk, pb, 0)
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
nb.Write(pb.Bytes())
}
@@ -379,7 +384,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) //[MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
@@ -389,11 +394,11 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
if pk.ProtocolVersion == 5 {
n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
if err != nil {
return ErrMalformedWillProperties
}
offset += n + 1
offset += n
}
pk.Connect.WillTopic, offset, err = decodeString(buf, offset)
@@ -408,6 +413,10 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
}
if pk.Connect.UsernameFlag { // [MQTT-3.1.3-12]
if offset >= len(buf) { // we are at the end of the packet
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
}
pk.Connect.Username, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedUsername
@@ -439,18 +448,14 @@ func (pk *Packet) ConnectValidate() Code {
return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3]
}
if len(pk.Connect.Password) > 65535 {
if len(pk.Connect.Password) > math.MaxUint16 {
return ErrProtocolViolationPasswordTooLong
}
if len(pk.Connect.Username) > 65535 {
if len(pk.Connect.Username) > math.MaxUint16 {
return ErrProtocolViolationUsernameTooLong
}
if pk.Connect.UsernameFlag && len(pk.Connect.Username) == 0 {
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
}
if !pk.Connect.UsernameFlag && len(pk.Connect.Username) > 0 {
return ErrProtocolViolationUsernameNoFlag // [MQTT-3.1.2-16]
}
@@ -463,7 +468,7 @@ func (pk *Packet) ConnectValidate() Code {
return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18]
}
if len(pk.Connect.ClientIdentifier) > 65535 {
if len(pk.Connect.ClientIdentifier) > math.MaxUint16 {
return ErrClientIdentifierNotValid
}
@@ -492,7 +497,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
nb.Write(pb.Bytes())
}
@@ -535,7 +540,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -604,7 +609,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload))
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
nb.Write(pb.Bytes())
}
@@ -640,7 +645,7 @@ func (pk *Packet) PublishDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.Payload = buf[offset:]
@@ -688,7 +693,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
nb.WriteByte(pk.ReasonCode)
}
@@ -829,7 +834,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes))
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
nb.Write(pb.Bytes())
}
@@ -857,7 +862,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
pk.ReasonCodes = buf[offset:]
@@ -886,7 +891,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -914,7 +919,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
var filter string
@@ -981,7 +986,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
}
@@ -1010,7 +1015,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
pk.ReasonCodes = buf[offset:]
}
@@ -1034,7 +1039,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes())
}
@@ -1062,7 +1067,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
if err != nil {
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
}
offset += n + 1
offset += n
}
var filter string
@@ -1097,7 +1102,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len())
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes())
pk.FixedHeader.Remaining = nb.Len()

View File

@@ -464,6 +464,9 @@ func TestCopy(t *testing.T) {
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
pkcc := tt.Packet.Copy(false)
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
}
}

View File

@@ -42,11 +42,11 @@ const (
// validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1},
PropMessageExpiryInterval: {Publish: 1},
PropContentType: {Publish: 1},
PropResponseTopic: {Publish: 1},
PropCorrelationData: {Publish: 1},
PropPayloadFormat: {Publish: 1, WillProperties: 1},
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
PropContentType: {Publish: 1, WillProperties: 1},
PropResponseTopic: {Publish: 1, WillProperties: 1},
PropCorrelationData: {Publish: 1, WillProperties: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1},
@@ -54,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {Connect: 1},
PropWillDelayInterval: {WillProperties: 1},
PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1},
@@ -64,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1},
@@ -194,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool {
}
// Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
if p == nil {
return
}
var buf bytes.Buffer
pkt := pk.FixedHeader.Type
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat)
@@ -217,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData))
}
@@ -277,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RequestResponseInfo)
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo))
}
@@ -289,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString)
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
buf.WriteByte(PropReasonString)
buf.Write(b)
}
@@ -322,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RetainAvailable)
}
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := bytes.NewBuffer([]byte{})
for _, v := range p.User {
pb.WriteByte(PropUser)
@@ -331,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
}
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
buf.Write(pb.Bytes())
}
}
@@ -361,18 +359,19 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
}
// Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
if p == nil {
return 0, nil
}
n, _, err = DecodeLength(b)
var bu int
n, bu, err = DecodeLength(b)
if err != nil {
return n, err
return n + bu, err
}
if n == 0 {
return n, nil
return n + bu, nil
}
bt := b.Bytes()
@@ -380,11 +379,11 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
for offset := 0; offset < n; {
k, offset, err = decodeByte(bt, offset)
if err != nil {
return n, err
return n + bu, err
}
if _, ok := validPacketProperties[k][pk]; !ok {
return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
if _, ok := validPacketProperties[k][pkt]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
}
switch k {
@@ -406,7 +405,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
if err != nil {
return n, err
return n + bu, err
}
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
offset += bu
@@ -452,7 +451,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
var k, v string
k, offset, err = decodeString(bt, offset)
if err != nil {
return n, err
return n + bu, err
}
v, offset, err = decodeString(bt, offset)
p.User = append(p.User, UserProperty{Key: k, Val: v})
@@ -470,9 +469,9 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
}
if err != nil {
return n, err
return n + bu, err
}
}
return n, nil
return n + bu, nil
}

View File

@@ -202,14 +202,14 @@ func init() {
func TestEncodeProperties(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes())
}
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
@@ -219,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
@@ -232,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) {
pr := tmp{}
b := bytes.NewBuffer([]byte{})
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
pr.p.Encode(Reserved, Mods{}, b, 0)
require.Equal(t, []byte{}, b.Bytes())
}
@@ -240,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties)
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes())
}
@@ -250,7 +250,7 @@ func TestDecodeProperties(t *testing.T) {
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172, n)
require.Equal(t, 172+2, n)
require.EqualValues(t, propertiesStruct, *props)
}

View File

@@ -71,7 +71,7 @@ const (
TConnectInvalidWillFlagNoPayload
TConnectInvalidWillFlagQosOutOfRange
TConnectInvalidWillSurplusRetain
TConnectNotCleanNoClientID
TConnectZeroByteUsername
TConnectSpecInvalidUTF8D800
TConnectSpecInvalidUTF8DFFF
TConnectSpecInvalidUTF80000
@@ -82,6 +82,7 @@ const (
TConnackAcceptedAdjustedExpiryInterval
TConnackMinMqtt5
TConnackMinCleanMqtt5
TConnackServerKeepalive
TConnackInvalidMinMqtt5
TConnackBadProtocolVersion
TConnackProtocolViolationNoSession
@@ -89,6 +90,7 @@ const (
TConnackServerUnavailable
TConnackBadUsernamePassword
TConnackBadUsernamePasswordNoSession
TConnackMqtt5BadUsernamePasswordNoSession
TConnackNotAuthorised
TConnackMalSessionPresent
TConnackMalReturnCode
@@ -249,26 +251,26 @@ var TPacketData = map[byte]TPacketCases{
Desc: "mqtt v3.1.1",
Primary: true,
RawBytes: []byte{
Connect << 4, 16, // Fixed header
Connect << 4, 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Packet Flags
0, 60, // Keepalive
0, 4, // Client ID - MSB+LSB
'z', 'e', 'n', '3', // Client ID "zen"
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connect,
Remaining: 16,
Remaining: 15,
},
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: false,
Keepalive: 60,
ClientIdentifier: "zen3",
ClientIdentifier: "zen",
},
},
},
@@ -425,9 +427,9 @@ var TPacketData = map[byte]TPacketCases{
Connect << 4, 28, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
194, // Packet Flags
0, 20, // Keepalive
4, // Protocol Version
0 | 1<<6 | 1<<7, // Packet Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 5, // Username MSB+LSB
@@ -443,7 +445,7 @@ var TPacketData = map[byte]TPacketCases{
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: true,
Clean: false,
Keepalive: 20,
ClientIdentifier: "zen",
UsernameFlag: true,
@@ -497,6 +499,43 @@ var TPacketData = map[byte]TPacketCases{
},
},
},
{
Case: TConnectZeroByteUsername,
Desc: "username flag but 0 byte username",
Group: "decode",
RawBytes: []byte{
Connect << 4, 23, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
5, // Protocol Version
130, // Packet Flags
0, 30, // Keepalive
5, // length
17, 0, 0, 0, 120, // Session Expiry Interval (17)
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 0, // Username MSB+LSB
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connect,
Remaining: 23,
},
ProtocolVersion: 5,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
Clean: true,
Keepalive: 30,
ClientIdentifier: "zen",
Username: []byte{},
UsernameFlag: true,
},
Properties: Properties{
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
},
},
},
// Fail States
{
@@ -623,6 +662,24 @@ var TPacketData = map[byte]TPacketCases{
'm', 'o', 'c',
},
},
{
Case: TConnectInvalidFlagNoUsername,
Desc: "username flag with no username bytes",
Group: "decode",
FailFirst: ErrProtocolViolationFlagNoUsername,
RawBytes: []byte{
Connect << 4, 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
5, // Protocol Version
130, // Flags
0, 20, // Keepalive
0,
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
},
},
{
Case: TConnectMalPassword,
Desc: "malformed password",
@@ -783,20 +840,6 @@ var TPacketData = map[byte]TPacketCases{
},
},
},
{
Case: TConnectInvalidFlagNoUsername,
Desc: "has username flag but no username",
Group: "validate",
Expect: ErrProtocolViolationFlagNoUsername,
Packet: &Packet{
FixedHeader: FixedHeader{Type: Connect},
ProtocolVersion: 4,
Connect: ConnectParams{
ProtocolName: []byte("MQTT"),
UsernameFlag: true,
},
},
},
{
Case: TConnectInvalidUsernameNoFlag,
Desc: "has username but no flag",
@@ -1043,25 +1086,22 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted, no session, adjusted expiry interval mqtt5",
Primary: true,
RawBytes: []byte{
Connack << 4, 11, // fixed header
Connack << 4, 8, // fixed header
0, // Session present
CodeSuccess.Code,
8, // length
5, // length
17, 0, 0, 0, 120, // Session Expiry Interval (17)
19, 0, 10, // Server Keep Alive (19)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 11,
Remaining: 8,
},
ReasonCode: CodeSuccess.Code,
Properties: Properties{
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
ServerKeepAlive: uint16(10),
ServerKeepAliveFlag: true,
},
},
},
@@ -1148,28 +1188,25 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted min properties mqtt5",
Primary: true,
RawBytes: []byte{
Connack << 4, 16, // fixed header
Connack << 4, 13, // fixed header
1, // existing session
CodeSuccess.Code,
13, // Properties length
10, // Properties length
18, 0, 5, 'm', 'o', 'c', 'h', 'i', // Assigned Client ID (18)
19, 0, 20, // Server Keep Alive (19)
36, 1, // Maximum Qos (36)
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 16,
Remaining: 13,
},
SessionPresent: true,
ReasonCode: CodeSuccess.Code,
Properties: Properties{
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
AssignedClientID: "mochi",
MaximumQos: byte(1),
MaximumQosFlag: true,
AssignedClientID: "mochi",
MaximumQos: byte(1),
MaximumQosFlag: true,
},
},
},
@@ -1178,11 +1215,10 @@ var TPacketData = map[byte]TPacketCases{
Desc: "accepted min properties mqtt5b",
Primary: true,
RawBytes: []byte{
Connack << 4, 6, // fixed header
Connack << 4, 3, // fixed header
0, // existing session
CodeSuccess.Code,
3, // Properties length
19, 0, 10, // server keepalive
0, // Properties length
},
Packet: &Packet{
ProtocolVersion: 5,
@@ -1192,6 +1228,27 @@ var TPacketData = map[byte]TPacketCases{
},
SessionPresent: false,
ReasonCode: CodeSuccess.Code,
},
},
{
Case: TConnackServerKeepalive,
Desc: "server set keepalive",
Primary: true,
RawBytes: []byte{
Connack << 4, 6, // fixed header
1, // existing session
CodeSuccess.Code,
3, // Properties length
19, 0, 10, // server keepalive
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 6,
},
SessionPresent: true,
ReasonCode: CodeSuccess.Code,
Properties: Properties{
ServerKeepAlive: uint16(10),
ServerKeepAliveFlag: true,
@@ -1203,26 +1260,23 @@ var TPacketData = map[byte]TPacketCases{
Desc: "failure min properties mqtt5",
Primary: true,
RawBytes: append([]byte{
Connack << 4, 26, // fixed header
Connack << 4, 23, // fixed header
0, // No existing session
ErrUnspecifiedError.Code,
// Properties
23, // length
19, 0, 20, // Server Keep Alive (19)
20, // length
31, 0, 17, // Reason String (31)
}, []byte(ErrUnspecifiedError.Reason)...),
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 25,
Remaining: 23,
},
SessionPresent: false,
ReasonCode: ErrUnspecifiedError.Code,
Properties: Properties{
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
ReasonString: ErrUnspecifiedError.Reason,
ReasonString: ErrUnspecifiedError.Reason,
},
},
},
@@ -1316,10 +1370,28 @@ var TPacketData = map[byte]TPacketCases{
Desc: "bad username or password no session",
RawBytes: []byte{
Connack << 4, 2, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0, // No session present
Err3NotAuthorized.Code, // use v3 remapping
},
Packet: &Packet{
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
},
ReasonCode: Err3NotAuthorized.Code,
},
},
{
Case: TConnackMqtt5BadUsernamePasswordNoSession,
Desc: "mqtt5 bad username or password no session",
RawBytes: []byte{
Connack << 4, 3, // fixed header
0, // No session present
ErrBadUsernameOrPassword.Code,
0,
},
Packet: &Packet{
ProtocolVersion: 5,
FixedHeader: FixedHeader{
Type: Connack,
Remaining: 2,
@@ -1327,6 +1399,7 @@ var TPacketData = map[byte]TPacketCases{
ReasonCode: ErrBadUsernameOrPassword.Code,
},
},
{
Case: TConnackNotAuthorised,
Desc: "not authorised",
@@ -1804,13 +1877,10 @@ var TPacketData = map[byte]TPacketCases{
Case: TPublishRetainMqtt5,
Desc: "retain mqtt5",
RawBytes: []byte{
Publish<<4 | 1<<0, 35, // Fixed header
Publish<<4 | 1<<0, 19, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
16, // properties length
38, // User Properties (38)
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
0, // properties length
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
},
Packet: &Packet{
@@ -1818,18 +1888,11 @@ var TPacketData = map[byte]TPacketCases{
FixedHeader: FixedHeader{
Type: Publish,
Retain: true,
Remaining: 35,
Remaining: 19,
},
TopicName: "a/b/c",
Properties: Properties{
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
},
},
Payload: []byte("hello mochi"),
TopicName: "a/b/c",
Properties: Properties{},
Payload: []byte("hello mochi"),
},
},
{

341
server.go
View File

@@ -26,10 +26,8 @@ import (
)
const (
Version = "2.0.5" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
defaultFanPoolSize uint64 = 32 // the number of concurrent workers in the pool
defaultFanPoolQueueSize uint64 = 1024 // the capacity of each worker queue
Version = "2.2.14" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
)
var (
@@ -45,8 +43,8 @@ var (
WildcardSubAvailable: 1, // wildcard subscriptions are available
SubIDAvailable: 1, // subscription identifiers are available
SharedSubAvailable: 1, // shared subscriptions are available
ServerKeepAlive: 10, // default keepalive for clients
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
}
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists.
@@ -56,18 +54,19 @@ var (
// Capabilities indicates the capabilities and features provided by the server.
type Capabilities struct {
MaximumMessageExpiryInterval int64
MaximumClientWritesPending int32
MaximumSessionExpiryInterval uint32
MaximumPacketSize uint32
maximumPacketID uint32 // unexported, used for testing only
ReceiveMaximum uint16
TopicAliasMaximum uint16
ServerKeepAlive uint16
SharedSubAvailable byte
MinimumProtocolVersion byte
Compatibilities Compatibilities
MaximumQos byte
RetainAvailable byte
WildcardSubAvailable byte
SubIDAvailable byte
SharedSubAvailable byte
MinimumProtocolVersion byte
}
// Compatibilities provides flags for using compatibility modes.
@@ -80,9 +79,17 @@ type Compatibilities struct {
// Options contains configurable options for the server.
type Options struct {
// Capabilities defines the server features and behaviour.
// Capabilities defines the server features and behaviour. If you only wish to modify
// several of these values, set them explicitly - e.g.
// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
Capabilities *Capabilities
// ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer.
ClientNetWriteBufferSize int
// ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer.
ClientNetReadBufferSize int
// Logger specifies a custom configured implementation of zerolog to override
// the servers default logger configuration. If you wish to change the log level,
// of the default logger, you can do so by setting
@@ -91,16 +98,6 @@ type Options struct {
// server.Log = &l
Logger *zerolog.Logger
// FanPoolSize is the number of individual workers and queues to initialize.
// Bigger is not necessarily better, and you should rely on defaults unless
// you have know what you are doing.
FanPoolSize uint64
// FanPoolQueueSize is the size of the queue per worker. Increase this value
// accordingly if you anticipate having intermittent but massive numbers of
// messages. Cluster support is roadmapped.
FanPoolQueueSize uint64
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
SysTopicResendInterval int64
}
@@ -113,7 +110,6 @@ type Server struct {
Clients *Clients // clients known to the broker
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
Info *system.Info // values about the server commonly known as $SYS topics
fanpool *FanPool // a fixed size worker pool for processing inbound and outbound messages
loop *loop // loop contains tickers for the system event loop
done chan bool // indicate that the server is ending
Log *zerolog.Logger // minimal no-alloc logger
@@ -132,10 +128,10 @@ type loop struct {
// ops contains server values which can be propagated to other structs.
type ops struct {
capabilities *Capabilities // a pointer to the server capabilities, for referencing in clients
info *system.Info // pointers to server system info
hooks *Hooks // pointer to the server hooks
log *zerolog.Logger // a structured logger for the client
options *Options // a pointer to the server options and capabilities, for referencing in clients
info *system.Info // pointers to server system info
hooks *Hooks // pointer to the server hooks
log *zerolog.Logger // a structured logger for the client
}
// New returns a new instance of mochi mqtt broker. Optional parameters
@@ -165,8 +161,7 @@ func New(opts *Options) *Server {
Version: Version,
Started: time.Now().Unix(),
},
fanpool: NewFanPool(opts.FanPoolSize, opts.FanPoolQueueSize),
Log: opts.Logger,
Log: opts.Logger,
hooks: &Hooks{
Log: opts.Logger,
},
@@ -181,16 +176,18 @@ func (o *Options) ensureDefaults() {
o.Capabilities = DefaultServerCapabilities
}
o.Capabilities.maximumPacketID = math.MaxUint16 // spec maximum is 65535
if o.SysTopicResendInterval == 0 {
o.SysTopicResendInterval = defaultSysTopicInterval
}
if o.FanPoolSize == 0 {
o.FanPoolSize = defaultFanPoolSize
if o.ClientNetWriteBufferSize == 0 {
o.ClientNetWriteBufferSize = 1024 * 2
}
if o.FanPoolQueueSize < 1 {
o.FanPoolQueueSize = defaultFanPoolQueueSize
if o.ClientNetReadBufferSize == 0 {
o.ClientNetReadBufferSize = 1024 * 2
}
if o.Logger == nil {
@@ -205,10 +202,10 @@ func (o *Options) ensureDefaults() {
// topic validation checks.
func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client {
cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit
capabilities: s.Options.Capabilities,
info: s.Info,
hooks: s.hooks,
log: s.Log,
options: s.Options,
info: s.Info,
hooks: s.hooks,
log: s.Log,
})
cl.ID = id
@@ -216,9 +213,11 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool)
if inline { // inline clients bypass acl and some validity checks.
cl.Net.Inline = true
// By default we don't want to restrict developer publishes,
// By default, we don't want to restrict developer publishes,
// but if you do, reset this after creating inline client.
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
} else {
go cl.WriteLoop() // can only write to real clients
}
return cl
@@ -323,16 +322,20 @@ func (s *Server) attachClient(cl *Client, listener string) error {
cl.ParseConnect(listener, pk)
code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2]
if code != packets.CodeSuccess {
if err := s.sendConnack(cl, code, false); err != nil {
if err := s.SendConnack(cl, code, false, nil); err != nil {
return fmt.Errorf("invalid connection send ack: %w", err)
}
return code // [MQTT-3.2.2-7] [MQTT-3.1.4-6]
}
s.hooks.OnConnect(cl, pk)
err = s.hooks.OnConnect(cl, pk)
if err != nil {
return err
}
cl.refreshDeadline(cl.State.Keepalive)
if !s.hooks.OnConnectAuthenticate(cl, pk) { // [MQTT-3.1.4-2]
err := s.sendConnack(cl, packets.ErrBadUsernameOrPassword, false)
err := s.SendConnack(cl, packets.ErrBadUsernameOrPassword, false, nil)
if err != nil {
return fmt.Errorf("invalid connection send ack: %w", err)
}
@@ -346,7 +349,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
sessionPresent := s.inheritClientSession(pk, cl)
s.Clients.Add(cl) // [MQTT-4.1.0-1]
err = s.sendConnack(cl, code, sessionPresent) // [MQTT-3.1.4-5] [MQTT-3.2.0-1] [MQTT-3.2.0-2] &[MQTT-3.14.0-1]
err = s.SendConnack(cl, code, sessionPresent, nil) // [MQTT-3.1.4-5] [MQTT-3.2.0-1] [MQTT-3.2.0-2] &[MQTT-3.14.0-1]
if err != nil {
return fmt.Errorf("ack connection packet: %w", err)
}
@@ -366,18 +369,17 @@ func (s *Server) attachClient(cl *Client, listener string) error {
if err != nil {
s.sendLWT(cl)
cl.Stop(err)
}
if err == nil {
} else {
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
}
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)
if expire {
s.unsubscribeClient(cl)
if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
cl.ClearInflights(math.MaxInt64, 0)
s.UnsubscribeClient(cl)
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
}
@@ -451,25 +453,40 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
// session is abandoned.
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
existing.Lock()
defer existing.Unlock()
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.unsubscribeClient(existing)
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
return false // [MQTT-3.2.2-3]
atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred
return false // [MQTT-3.2.2-3]
}
atomic.StoreUint32(&existing.State.isTakenOver, 1)
if existing.State.Inflight.Len() > 0 {
cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5]
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 {
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
}
}
cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5]
for _, sub := range existing.State.Subscriptions.GetAll() {
existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if !existed {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sub.Filter, sub)
s.publishRetainedToClient(cl, sub, existed)
}
// Clean the state of the existing client to prevent sequential take-overs
// from increasing memory usage by inflights + subs * client-id.
s.UnsubscribeClient(existing)
existing.ClearInflights(math.MaxInt64, 0)
s.Log.Debug().Str("client", cl.ID).
Str("old_remote", existing.Net.Remote).
Str("new_remote", cl.Net.Remote).
Msg("session taken over")
return true // [MQTT-3.2.2-3]
}
@@ -480,15 +497,27 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
return false // [MQTT-3.2.2-2]
}
// sendConnack returns a Connack packet to a client.
func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) error {
properties := packets.Properties{
ServerKeepAlive: s.Options.Capabilities.ServerKeepAlive, // [MQTT-3.1.2-21]
ServerKeepAliveFlag: true,
ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum, // 3.2.2.3.3 Receive Maximum
// SendConnack returns a Connack packet to a client.
func (s *Server) SendConnack(cl *Client, reason packets.Code, present bool, properties *packets.Properties) error {
if properties == nil {
properties = &packets.Properties{
ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum,
}
}
properties.ReceiveMaximum = s.Options.Capabilities.ReceiveMaximum // 3.2.2.3.3 Receive Maximum
if cl.State.ServerKeepalive { // You can set this dynamically using the OnConnect hook.
properties.ServerKeepAlive = cl.State.Keepalive // [MQTT-3.1.2-21]
properties.ServerKeepAliveFlag = true
}
if reason.Code >= packets.ErrUnspecifiedError.Code {
if cl.Properties.ProtocolVersion < 5 {
if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes
reason = v3reason
}
}
properties.ReasonString = reason.Reason
ack := packets.Packet{
FixedHeader: packets.FixedHeader{
@@ -496,9 +525,8 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro
},
SessionPresent: false, // [MQTT-3.2.2-6]
ReasonCode: reason.Code, // [MQTT-3.2.2-8]
Properties: properties,
Properties: *properties,
}
return cl.WritePacket(ack)
}
@@ -518,14 +546,15 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro
cl.Properties.Props.SessionExpiryIntervalFlag = true
}
return cl.WritePacket(packets.Packet{
ack := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Connack,
},
SessionPresent: present,
ReasonCode: reason.Code, // [MQTT-3.2.2-8]
Properties: properties,
})
Properties: *properties,
}
return cl.WritePacket(ack)
}
// processPacket processes an inbound packet for a client. Since the method is
@@ -588,7 +617,7 @@ func (s *Server) processPacket(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Delete(next.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
}
cl.State.Inflight.TakeSendQuota()
cl.State.Inflight.DecreaseSendQuota()
}
}
@@ -649,7 +678,7 @@ func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
// processPublish processes a Publish packet.
func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if !IsValidFilter(pk.TopicName, true) && !cl.Net.Inline {
if !cl.Net.Inline && !IsValidFilter(pk.TopicName, true) {
return nil
}
@@ -657,20 +686,22 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8]
}
if !s.hooks.OnACLCheck(cl, pk.TopicName, true) && !cl.Net.Inline {
if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) {
return nil
}
pk.Origin = cl.ID
pk.Created = time.Now().Unix()
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok && !cl.Net.Inline {
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
return cl.WritePacket(ack)
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
atomic.AddInt64(&s.Info.Inflight, -1)
if !cl.Net.Inline {
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
return cl.WritePacket(ack)
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
atomic.AddInt64(&s.Info.Inflight, -1)
}
}
}
@@ -693,15 +724,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
}
if pk.FixedHeader.Qos == 0 {
s.fanpool.Enqueue(cl.ID, func() {
s.publishToSubscribers(pk)
})
s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk)
return nil
}
cl.State.Inflight.TakeReceiveQuota()
cl.State.Inflight.DecreaseReceiveQuota()
ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4]
if pk.FixedHeader.Qos == 2 {
ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8]
@@ -709,6 +737,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Set(ack); ok {
atomic.AddInt64(&s.Info.Inflight, 1)
s.hooks.OnQosPublish(cl, ack, ack.Created, 0)
}
err := cl.WritePacket(ack)
@@ -720,20 +749,18 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
if ok := cl.State.Inflight.Delete(ack.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
}
cl.State.Inflight.ReturnReceiveQuota()
s.hooks.OnQosComplete(cl, pk)
cl.State.Inflight.IncreaseReceiveQuota()
s.hooks.OnQosComplete(cl, ack)
}
s.fanpool.Enqueue(cl.ID, func() {
s.publishToSubscribers(pk)
})
s.publishToSubscribers(pk)
s.hooks.OnPublished(cl, pk)
return nil
}
// retainMessage adds a message to a topic, and if a persistent store is provided,
// adds the message to the store so it can be reloaded if necessary.
// adds the message to the store to be reloaded if necessary.
func (s *Server) retainMessage(cl *Client, pk packets.Packet) {
out := pk.Copy(false)
r := s.Topics.RetainMessage(out)
@@ -771,13 +798,13 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
}
}
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (out packets.Packet, err error) {
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
if sub.NoLocal && pk.Origin == cl.ID {
return pk, nil // [MQTT-3.8.3-3]
}
out = pk.Copy(false)
if !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
out := pk.Copy(false)
if cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
out.FixedHeader.Retain = false // [MQTT-3.3.1-12]
}
@@ -807,6 +834,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
if out.FixedHeader.Qos > 0 {
i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1]
if err != nil {
s.hooks.OnPacketIDExhausted(cl, pk)
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted")
return out, packets.ErrQuotaExceeded
}
@@ -817,22 +845,32 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3]
atomic.AddInt64(&s.Info.Inflight, 1)
s.hooks.OnQosPublish(cl, out, out.Created, 0)
cl.State.Inflight.DecreaseSendQuota()
}
if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 {
out.Expiry = -1
cl.State.Inflight.Set(out)
return pk, nil
return out, nil
}
}
if cl.Net.conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
return pk, packets.CodeDisconnect
if cl.Net.Conn == nil || cl.Closed() {
return out, packets.CodeDisconnect
}
cl.State.Inflight.TakeSendQuota()
select {
case cl.State.outbound <- &out:
atomic.AddInt32(&cl.State.outboundQty, 1)
default:
atomic.AddInt64(&s.Info.MessagesDropped, 1)
cl.ops.hooks.OnPublishDropped(cl, pk)
cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight.
cl.State.Inflight.IncreaseSendQuota()
return out, packets.ErrPendingClientWritesExceeded
}
return out, cl.WritePacket(out)
return out, nil
}
func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, existed bool) {
@@ -847,12 +885,14 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e
for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4]
_, err := s.publishToClient(cl, sub, pkv)
if err != nil {
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
continue
}
s.hooks.OnRetainPublished(cl, pkv)
}
}
// buildAck builds an standardised ack message for Puback, Pubrec, Pubrel, Pubcomp packets.
// buildAck builds a standardised ack message for Puback, Pubrec, Pubrel, Pubcomp packets.
func (s *Server) buildAck(packetID uint16, pkt, qos byte, properties packets.Properties, reason packets.Code) packets.Packet {
properties = packets.Properties{} // PRL
if reason.Code >= packets.ErrUnspecifiedError.Code {
@@ -881,7 +921,7 @@ func (s *Server) processPuback(cl *Client, pk packets.Packet) error {
}
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
cl.State.Inflight.ReturnSendQuota()
cl.State.Inflight.IncreaseSendQuota()
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
}
@@ -904,7 +944,7 @@ func (s *Server) processPubrec(cl *Client, pk packets.Packet) error {
}
ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6]
cl.State.Inflight.TakeReceiveQuota() // -1 RECV QUOTA
cl.State.Inflight.DecreaseReceiveQuota() // -1 RECV QUOTA
cl.State.Inflight.Set(ack) // [MQTT-4.3.3-5]
return cl.WritePacket(ack)
}
@@ -931,8 +971,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
return err
}
cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.ReturnSendQuota() // +1 SENT QUOTA
cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12]
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
@@ -944,8 +984,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
// processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server.
func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error {
// regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas.
cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.ReturnSendQuota() // +1 SENT QUOTA
cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
if ok := cl.State.Inflight.Delete(pk.PacketID); ok {
atomic.AddInt64(&s.Info.Inflight, -1)
s.hooks.OnQosComplete(cl, pk)
@@ -962,24 +1002,24 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
code = packets.ErrPacketIdentifierInUse
}
existed := false
filterExisted := make([]bool, len(pk.Filters))
reasonCodes := make([]byte, len(pk.Filters))
for i, sub := range pk.Filters {
if code != packets.CodeSuccess {
reasonCodes[i] = code.Code // NB 3.9.3 Non-normative 0x91
continue
} else if !IsValidFilter(sub.Filter, false) {
reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
} else if !s.hooks.OnACLCheck(cl, sub.Filter, false) {
reasonCodes[i] = packets.ErrNotAuthorized.Code
if s.Options.Capabilities.Compatibilities.ObscureNotAuthorized {
reasonCodes[i] = packets.ErrUnspecifiedError.Code
}
} else if !IsValidFilter(sub.Filter, false) {
reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
} else {
existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if !existed {
isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
if isNew {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]
@@ -988,6 +1028,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
}
filterExisted[i] = !isNew
reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
}
@@ -996,7 +1037,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
}
}
ack := packets.Packet{ //[MQTT-3.8.4-1] [MQTT-3.8.4-5]
ack := packets.Packet{ // [MQTT-3.8.4-1] [MQTT-3.8.4-5]
FixedHeader: packets.FixedHeader{
Type: packets.Suback,
},
@@ -1022,7 +1063,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
continue
}
s.publishRetainedToClient(cl, sub, existed)
s.publishRetainedToClient(cl, sub, filterExisted[i])
}
return nil
@@ -1072,20 +1113,32 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
return cl.WritePacket(ack)
}
// unsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) unsubscribeClient(cl *Client) {
for k := range cl.State.Subscriptions.GetAll() {
// UnsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) UnsubscribeClient(cl *Client) {
i := 0
filterMap := cl.State.Subscriptions.GetAll()
filters := make([]packets.Subscription, len(filterMap))
for k := range filterMap {
cl.State.Subscriptions.Delete(k)
}
if atomic.LoadUint32(&cl.State.isTakenOver) == 1 {
return
}
for k, v := range filterMap {
if s.Topics.Unsubscribe(k, cl.ID) {
atomic.AddInt64(&s.Info.Subscriptions, -1)
}
filters[i] = v
i++
}
s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
}
// processAuth processes an Auth packet.
func (s *Server) processAuth(cl *Client, pk packets.Packet) error {
_, err := s.hooks.OnAuthPacket(cl, pk)
fmt.Println("err", err)
if err != nil {
return err
}
@@ -1173,6 +1226,7 @@ func (s *Server) publishSysTopics() {
SysPrefix + "/broker/packets/sent": AtomicItoa(&s.Info.PacketsSent),
SysPrefix + "/broker/messages/received": AtomicItoa(&s.Info.MessagesReceived),
SysPrefix + "/broker/messages/sent": AtomicItoa(&s.Info.MessagesSent),
SysPrefix + "/broker/messages/dropped": AtomicItoa(&s.Info.MessagesDropped),
SysPrefix + "/broker/messages/inflight": AtomicItoa(&s.Info.Inflight),
SysPrefix + "/broker/retained": AtomicItoa(&s.Info.Retained),
SysPrefix + "/broker/subscriptions": AtomicItoa(&s.Info.Subscriptions),
@@ -1190,12 +1244,10 @@ func (s *Server) publishSysTopics() {
s.hooks.OnSysInfoTick(s.Info)
}
// Close attempts to gracefully shutdown the server, all listeners, clients, and stores.
// Close attempts to gracefully shut down the server, all listeners, clients, and stores.
func (s *Server) Close() error {
close(s.done)
s.Listeners.CloseAll(s.closeListenerClients)
s.fanpool.Close()
s.fanpool.Wait()
s.hooks.OnStopped()
s.hooks.Stop()
@@ -1241,6 +1293,10 @@ func (s *Server) sendLWT(cl *Client) {
return
}
if pk.FixedHeader.Retain {
s.retainMessage(cl, pk)
}
s.publishToSubscribers(pk) // [MQTT-3.1.2-8]
atomic.StoreUint32(&cl.Properties.Will.Flag, 0) // [MQTT-3.1.2-10]
s.hooks.OnWillSent(cl, pk)
@@ -1315,6 +1371,7 @@ func (s *Server) loadServerInfo(v system.Info) {
atomic.StoreInt64(&s.Info.ClientsDisconnected, v.ClientsDisconnected)
atomic.StoreInt64(&s.Info.MessagesReceived, v.MessagesReceived)
atomic.StoreInt64(&s.Info.MessagesSent, v.MessagesSent)
atomic.StoreInt64(&s.Info.MessagesDropped, v.MessagesDropped)
atomic.StoreInt64(&s.Info.PacketsReceived, v.PacketsReceived)
atomic.StoreInt64(&s.Info.PacketsSent, v.PacketsSent)
atomic.StoreInt64(&s.Info.InflightDropped, v.InflightDropped)
@@ -1372,25 +1429,7 @@ func (s *Server) loadClients(v []storage.Client) {
func (s *Server) loadInflight(v []storage.Message) {
for _, msg := range v {
if client, ok := s.Clients.Get(msg.Origin); ok {
client.State.Inflight.Set(packets.Packet{
FixedHeader: msg.FixedHeader,
PacketID: msg.PacketID,
TopicName: msg.TopicName,
Payload: msg.Payload,
Origin: msg.Origin,
Created: msg.Created,
Properties: packets.Properties{
PayloadFormat: msg.Properties.PayloadFormat,
PayloadFormatFlag: msg.Properties.PayloadFormatFlag,
MessageExpiryInterval: msg.Properties.MessageExpiryInterval,
ContentType: msg.Properties.ContentType,
ResponseTopic: msg.Properties.ResponseTopic,
CorrelationData: msg.Properties.CorrelationData,
SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier,
TopicAlias: msg.Properties.TopicAlias,
User: msg.Properties.User,
},
})
client.State.Inflight.Set(msg.ToPacket())
}
}
}
@@ -1398,24 +1437,7 @@ func (s *Server) loadInflight(v []storage.Message) {
// loadRetained restores retained messages from the datastore.
func (s *Server) loadRetained(v []storage.Message) {
for _, msg := range v {
s.Topics.RetainMessage(packets.Packet{
FixedHeader: msg.FixedHeader,
TopicName: msg.TopicName,
Payload: msg.Payload,
Origin: msg.Origin,
Created: msg.Created,
Properties: packets.Properties{
PayloadFormat: msg.Properties.PayloadFormat,
PayloadFormatFlag: msg.Properties.PayloadFormatFlag,
MessageExpiryInterval: msg.Properties.MessageExpiryInterval,
ContentType: msg.Properties.ContentType,
ResponseTopic: msg.Properties.ResponseTopic,
CorrelationData: msg.Properties.CorrelationData,
SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier,
TopicAlias: msg.Properties.TopicAlias,
User: msg.Properties.User,
},
})
s.Topics.RetainMessage(msg.ToPacket())
}
}
@@ -1453,8 +1475,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) {
// clearExpiredInflights deletes any inflight messages which have expired.
func (s *Server) clearExpiredInflights(now int64) {
for _, client := range s.Clients.GetAll() {
if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 {
s.hooks.OnExpireInflights(client, now)
if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
for _, id := range deleted {
s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
}
}
}
}
@@ -1465,6 +1489,9 @@ func (s *Server) sendDelayedLWT(dt int64) {
if dt > pk.Expiry {
s.publishToSubscribers(pk) // [MQTT-3.1.2-8]
if cl, ok := s.Clients.Get(id); ok {
if pk.FixedHeader.Retain {
s.retainMessage(cl, pk)
}
cl.Properties.Will = Will{} // [MQTT-3.1.2-10]
s.hooks.OnWillSent(cl, pk)
}

View File

@@ -48,16 +48,31 @@ func (h *AllowHook) Provides(b byte) bool {
func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return true }
func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool { return true }
type DelayHook struct {
HookBase
DisconnectDelay time.Duration
}
func (h *DelayHook) ID() string {
return "delay-hook"
}
func (h *DelayHook) Provides(b byte) bool {
return bytes.Contains([]byte{OnDisconnect}, []byte{b})
}
func (h *DelayHook) OnDisconnect(cl *Client, err error, expire bool) {
time.Sleep(h.DisconnectDelay)
}
func newServer() *Server {
cc := *DefaultServerCapabilities
cc.MaximumMessageExpiryInterval = 0
cc.ReceiveMaximum = 0
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Capabilities: &cc,
Logger: &logger,
Capabilities: &cc,
})
s.AddHook(new(AllowHook), nil)
return s
@@ -68,8 +83,6 @@ func TestOptionsSetDefaults(t *testing.T) {
opts.ensureDefaults()
require.Equal(t, defaultSysTopicInterval, opts.SysTopicResendInterval)
require.Equal(t, defaultFanPoolSize, opts.FanPoolSize)
require.Equal(t, defaultFanPoolQueueSize, opts.FanPoolQueueSize)
require.Equal(t, DefaultServerCapabilities, opts.Capabilities)
opts = new(Options)
@@ -86,7 +99,6 @@ func TestNew(t *testing.T) {
require.NotNil(t, s.Info)
require.NotNil(t, s.Log)
require.NotNil(t, s.Options)
require.NotNil(t, s.fanpool)
require.NotNil(t, s.loop)
require.NotNil(t, s.loop.sysTopics)
require.NotNil(t, s.loop.inflightExpiry)
@@ -115,9 +127,9 @@ func TestServerNewClient(t *testing.T) {
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultKeepalive, cl.State.Keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.conn)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.Equal(t, s.Log, cl.ops.log)
@@ -406,20 +418,22 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
cl, r0, _ := newTestClient()
cl.Properties.ProtocolVersion = 5
cl.Properties.Username = []byte("mochi")
cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
cl.State.Inflight.Set(*packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
s.Clients.Add(cl)
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r)
err := s.EstablishConnection("tcp", r)
o <- err
}()
go func() {
w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect.
w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
}()
@@ -445,9 +459,14 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
}
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, <-recv)
connackPlusPacket := append(
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes,
packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes...,
)
require.Equal(t, connackPlusPacket, <-recv)
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover)
time.Sleep(time.Microsecond * 100)
w.Close()
r.Close()
@@ -455,9 +474,99 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
require.True(t, ok)
require.NotEmpty(t, clw.State.Subscriptions)
sub, ok := cl.State.Subscriptions.Get("a/b/c")
// Prevent sequential takeover memory-bloom.
require.Empty(t, cl.State.Subscriptions.GetAll())
}
// See https://github.com/mochi-co/mqtt/issues/173
func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) {
s := newServer()
d := new(DelayHook)
d.DisconnectDelay = time.Millisecond * 200
s.AddHook(d, nil)
defer s.Close()
// Clean session, 0 session expiry interval
cl1RawBytes := []byte{
packets.Connect << 4, 21, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
5, // Protocol Version
1 << 1, // Packet Flags
0, 30, // Keepalive
5, // Properties length
17, 0, 0, 0, 0, // Session Expiry Interval (17)
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
}
// Make first connection
r1, w1 := net.Pipe()
o1 := make(chan error)
go func() {
err := s.EstablishConnection("tcp", r1)
o1 <- err
}()
go func() {
w1.Write(cl1RawBytes)
}()
// receive the first connack
recv := make(chan []byte)
go func() {
buf, err := io.ReadAll(w1)
require.NoError(t, err)
recv <- buf
}()
// Get the first client pointer
time.Sleep(time.Millisecond * 50)
cl1, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).Packet.Connect.ClientIdentifier)
require.True(t, ok)
require.Equal(t, packets.Subscription{Filter: "a/b/c", Qos: 1}, sub)
cl1.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
cl1.State.Subscriptions.Add("d/e/f", packets.Subscription{Filter: "d/e/f", Qos: 0})
time.Sleep(time.Millisecond * 50)
// Make the second connection
r2, w2 := net.Pipe()
o2 := make(chan error)
go func() {
err := s.EstablishConnection("tcp", r2)
o2 <- err
}()
go func() {
x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:]
x[19] = '.' // differentiate username bytes in debugging
w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes)
}()
// receive the second connack
recv2 := make(chan []byte)
go func() {
buf, err := io.ReadAll(w2)
require.NoError(t, err)
recv2 <- buf
}()
// Capture first Client pointer
clp1, ok := s.Clients.Get("zen")
require.True(t, ok)
require.Empty(t, clp1.Properties.Username)
require.NotEmpty(t, clp1.State.Subscriptions.GetAll())
err1 := <-o1
require.Error(t, err1)
require.ErrorIs(t, err1, io.ErrClosedPipe)
// Capture second Client pointer
clp2, ok := s.Clients.Get("zen")
require.True(t, ok)
require.Equal(t, []byte(".ochi"), clp2.Properties.Username)
require.NotEmpty(t, clp2.State.Subscriptions.GetAll())
require.Empty(t, clp1.State.Subscriptions.GetAll())
w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
require.NoError(t, <-o2)
}
func TestEstablishConnectionResentPendingInflightsError(t *testing.T) {
@@ -553,9 +662,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) {
func TestEstablishConnectionBadAuthentication(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
defer s.Close()
@@ -589,9 +696,7 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) {
func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
defer s.Close()
@@ -643,6 +748,33 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) {
r.Close()
}
// See https://github.com/mochi-co/mqtt/issues/178
func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) {
s := newServer()
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r)
}()
go func() {
w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes)
w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
}()
// receive the connack error
go func() {
_, err := io.ReadAll(w)
require.NoError(t, err)
}()
err := <-o
require.NoError(t, err)
r.Close()
}
func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) {
s := newServer()
@@ -685,17 +817,40 @@ func TestServerEstablishConnectionBadPacket(t *testing.T) {
r.Close()
}
func TestServerEstablishConnectionOnConnectError(t *testing.T) {
s := newServer()
hook := new(modifiedHookBase)
hook.fail = true
err := s.AddHook(hook, nil)
require.NoError(t, err)
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r)
}()
go func() {
w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
}()
err = <-o
require.Error(t, err)
require.ErrorIs(t, err, errTestHook)
r.Close()
}
func TestServerSendConnack(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
s.Options.Capabilities.ServerKeepAlive = 20
s.Options.Capabilities.MaximumQos = 1
cl.Properties.Props = packets.Properties{
AssignedClientID: "mochi",
}
go func() {
err := s.sendConnack(cl, packets.CodeSuccess, true)
err := s.SendConnack(cl, packets.CodeSuccess, true, nil)
require.NoError(t, err)
w.Close()
}()
@@ -709,9 +864,8 @@ func TestServerSendConnackFailureReason(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
s.Options.Capabilities.ServerKeepAlive = 20
go func() {
err := s.sendConnack(cl, packets.ErrUnspecifiedError, true)
err := s.SendConnack(cl, packets.ErrUnspecifiedError, true, nil)
require.NoError(t, err)
w.Close()
}()
@@ -721,6 +875,23 @@ func TestServerSendConnackFailureReason(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackInvalidMinMqtt5).RawBytes, buf)
}
func TestServerSendConnackWithServerKeepalive(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
cl.Properties.ProtocolVersion = 5
cl.State.Keepalive = 10
cl.State.ServerKeepalive = true
go func() {
err := s.SendConnack(cl, packets.CodeSuccess, true, nil)
require.NoError(t, err)
w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackServerKeepalive).RawBytes, buf)
}
func TestServerValidateConnect(t *testing.T) {
packet := *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet
invalidBitPacket := packet
@@ -790,7 +961,7 @@ func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) {
cl.Properties.Props.SessionExpiryInterval = uint32(300)
s.Options.Capabilities.MaximumSessionExpiryInterval = 120
go func() {
err := s.sendConnack(cl, packets.CodeSuccess, false)
err := s.SendConnack(cl, packets.CodeSuccess, false, nil)
require.NoError(t, err)
w.Close()
}()
@@ -806,7 +977,7 @@ func TestInheritClientSession(t *testing.T) {
n := time.Now().Unix()
existing, _, _ := newTestClient()
existing.Net.conn = nil
existing.Net.Conn = nil
existing.ID = "mochi"
existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
existing.State.Inflight = NewInflights()
@@ -844,7 +1015,7 @@ func TestServerUnsubscribeClient(t *testing.T) {
s.Topics.Subscribe(cl.ID, pk)
subs := s.Topics.Subscribers("a/b/c")
require.Equal(t, 1, len(subs.Subscriptions))
s.unsubscribeClient(cl)
s.UnsubscribeClient(cl)
subs = s.Topics.Subscribers("a/b/c")
require.Equal(t, 0, len(subs.Subscriptions))
}
@@ -1023,7 +1194,7 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) {
w2.Close()
}()
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, <-receiverBuf)
require.Equal(t, 1, len(s.Topics.Messages("a/b/c")))
}
@@ -1098,9 +1269,7 @@ func TestServerProcessPublishInvalidTopic(t *testing.T) {
func TestServerProcessPublishACLCheckDeny(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
defer s.Close()
@@ -1383,6 +1552,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
pkx.FixedHeader.Qos = 2
s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx)
time.Sleep(time.Microsecond * 100)
w.Close()
}()
@@ -1396,6 +1566,33 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf)
}
func TestPublishToClientExceedClientWritesPending(t *testing.T) {
s := newServer()
_, w := net.Pipe()
cl := newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
options: &Options{
Capabilities: &Capabilities{
MaximumClientWritesPending: 3,
},
},
})
s.Clients.Add(cl)
for i := int32(0); i < cl.ops.options.Capabilities.MaximumClientWritesPending; i++ {
cl.State.outbound <- new(packets.Packet)
atomic.AddInt32(&cl.State.outboundQty, 1)
}
_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{})
require.Error(t, err)
require.ErrorIs(t, packets.ErrPendingClientWritesExceeded, err)
}
func TestPublishToClientServerTopicAlias(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
@@ -1407,6 +1604,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) {
pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet
s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx)
s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx)
time.Sleep(time.Millisecond)
w.Close()
}()
@@ -1428,7 +1626,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) {
func TestPublishToClientExhaustedPacketID(t *testing.T) {
s := newServer()
cl, _, _ := newTestClient()
for i := 0; i <= 65535; i++ {
for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
}
@@ -1440,7 +1638,7 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) {
func TestPublishToClientNoConn(t *testing.T) {
s := newServer()
cl, _, _ := newTestClient()
cl.Net.conn = nil
cl.Net.Conn = nil
_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
require.Error(t, err)
@@ -1497,7 +1695,7 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
s := newServer()
cl, r, w := newTestClient()
s.Clients.Add(cl)
for i := 0; i <= 65535; i++ {
for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.Set(packets.Packet{PacketID: 1})
}
@@ -1537,7 +1735,7 @@ func TestPublishRetainedToClient(t *testing.T) {
subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
require.True(t, subbed)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetainMqtt5).Packet)
require.Equal(t, int64(1), retained)
go func() {
@@ -1548,7 +1746,7 @@ func TestPublishRetainedToClient(t *testing.T) {
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, buf)
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, buf)
}
func TestPublishRetainedToClientIsShared(t *testing.T) {
@@ -1863,7 +2061,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) {
for i, tx := range tt {
t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
r, w = net.Pipe()
cl.Net.conn = w
cl.Net.Conn = w
recv := make(chan []byte)
go func() { // receive the ack
@@ -1937,7 +2135,8 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
for i, tx := range tt {
t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
r, w := net.Pipe()
cl.Net.conn = w
time.Sleep(time.Millisecond)
cl.Net.Conn = w
recv := make(chan []byte)
go func() { // receive the ack
@@ -1953,6 +2152,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) {
require.NoError(t, err)
}
time.Sleep(time.Millisecond)
w.Close()
if i != 2 {
@@ -2064,7 +2264,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) {
require.NoError(t, err)
require.Equal(t, append(
packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes,
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes...,
packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes...,
), buf)
}
@@ -2164,9 +2364,7 @@ func TestServerProcessSubscribeNoConnection(t *testing.T) {
func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
cl, r, w := newTestClient()
@@ -2185,9 +2383,7 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
s := New(&Options{
Logger: &logger,
FanPoolSize: 2,
FanPoolQueueSize: 10,
Logger: &logger,
})
s.Serve()
s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
@@ -2319,7 +2515,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 0, s.loop.willDelayed.Len())
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
require.True(t, cl.Closed())
require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected))
}
@@ -2414,6 +2610,46 @@ func TestServerSendLWT(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
}
func TestServerSendLWTRetain(t *testing.T) {
s := newServer()
s.Serve()
defer s.Close()
sender, _, w1 := newTestClient()
sender.ID = "sender"
sender.Properties.Will = Will{
Flag: 1,
TopicName: "a/b/c",
Payload: []byte("hello mochi"),
Retain: true,
}
s.Clients.Add(sender)
receiver, r2, w2 := newTestClient()
receiver.ID = "receiver"
s.Clients.Add(receiver)
s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})
require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))
receiverBuf := make(chan []byte)
go func() {
buf, err := io.ReadAll(r2)
require.NoError(t, err)
receiverBuf <- buf
}()
go func() {
s.sendLWT(sender)
time.Sleep(time.Millisecond * 10)
w1.Close()
w2.Close()
}()
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, <-receiverBuf)
}
func TestServerSendLWTDelayed(t *testing.T) {
s := newServer()
cl1, _, _ := newTestClient()
@@ -2452,7 +2688,7 @@ func TestServerSendLWTDelayed(t *testing.T) {
recv <- buf
}()
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-recv)
require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, <-recv)
}
func TestServerReadStore(t *testing.T) {
@@ -2573,7 +2809,6 @@ func TestServerClose(t *testing.T) {
err := s.AddListener(listeners.NewMockListener("t1", ":1882"))
require.NoError(t, err)
s.Serve()
require.Equal(t, uint64(2), s.fanpool.Size())
// receive the disconnect
recv := make(chan []byte)
@@ -2593,7 +2828,6 @@ func TestServerClose(t *testing.T) {
s.Close()
time.Sleep(time.Millisecond)
require.Equal(t, false, listener.(*listeners.MockListener).IsServing())
require.Equal(t, uint64(0), s.fanpool.Size())
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv)
}
@@ -2651,7 +2885,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl0, _, _ := newTestClient()
cl0.ID = "c0"
cl0.State.disconnected = n - 10
cl0.State.done = 1
cl0.State.cancelOpen()
cl0.Properties.ProtocolVersion = 5
cl0.Properties.Props.SessionExpiryInterval = 12
cl0.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2661,7 +2895,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl1, _, _ := newTestClient()
cl1.ID = "c1"
cl1.State.disconnected = n - 10
cl1.State.done = 1
cl1.State.cancelOpen()
cl1.Properties.ProtocolVersion = 5
cl1.Properties.Props.SessionExpiryInterval = 8
cl1.Properties.Props.SessionExpiryIntervalFlag = true
@@ -2671,7 +2905,7 @@ func TestServerClearExpiredClients(t *testing.T) {
cl2, _, _ := newTestClient()
cl2.ID = "c2"
cl2.State.disconnected = n - 10
cl2.State.done = 1
cl2.State.cancelOpen()
cl2.Properties.ProtocolVersion = 5
cl2.Properties.Props.SessionExpiryInterval = 0
cl2.Properties.Props.SessionExpiryIntervalFlag = true

View File

@@ -4,6 +4,8 @@
package system
import "sync/atomic"
// Info contains atomic counters and values for various server statistics
// commonly found in $SYS topics (and others).
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
@@ -20,6 +22,7 @@ type Info struct {
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
Retained int64 `json:"retained"` // total number of retained messages active on the broker
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
@@ -29,3 +32,30 @@ type Info struct {
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
}
// Clone makes a copy of Info using atomic operation
func (i *Info) Clone() *Info {
return &Info{
Version: i.Version,
Started: atomic.LoadInt64(&i.Started),
Time: atomic.LoadInt64(&i.Time),
Uptime: atomic.LoadInt64(&i.Uptime),
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
BytesSent: atomic.LoadInt64(&i.BytesSent),
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
Retained: atomic.LoadInt64(&i.Retained),
Inflight: atomic.LoadInt64(&i.Inflight),
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
Threads: atomic.LoadInt64(&i.Threads),
}
}

37
system/system_test.go Normal file
View File

@@ -0,0 +1,37 @@
package system
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestClone(t *testing.T) {
o := &Info{
Version: "version",
Started: 1,
Time: 2,
Uptime: 3,
BytesReceived: 4,
BytesSent: 5,
ClientsConnected: 6,
ClientsMaximum: 7,
ClientsTotal: 8,
ClientsDisconnected: 9,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
Retained: 12,
Inflight: 13,
InflightDropped: 14,
Subscriptions: 15,
PacketsReceived: 16,
PacketsSent: 17,
MemoryAlloc: 18,
Threads: 19,
}
n := o.Clone()
require.Equal(t, o, n)
}

View File

@@ -301,6 +301,9 @@ func NewTopicsIndex() *TopicsIndex {
// Subscribe adds a new subscription for a client to a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
prefix, _ := isolateParticle(subscription.Filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
@@ -320,8 +323,13 @@ func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription
// Unsubscribe removes a subscription filter for a client, returning true if the
// subscription existed.
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
x.root.Lock()
defer x.root.Unlock()
var d int
if strings.HasPrefix(filter, SharePrefix) {
prefix, _ := isolateParticle(filter, 0)
shareSub := strings.EqualFold(prefix, SharePrefix)
if shareSub {
d = 2
}
@@ -330,8 +338,7 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
return false
}
prefix, _ := isolateParticle(filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
if shareSub {
group, _ := isolateParticle(filter, 1)
particle.shared.Delete(group, client)
} else {
@@ -346,7 +353,12 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
x.root.Lock()
defer x.root.Unlock()
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()
if len(pk.Payload) > 0 {
n.retainPath = pk.TopicName
x.Retained.Add(pk.TopicName, pk)
@@ -361,6 +373,7 @@ func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
n.retainPath = ""
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
x.trim(n)
return out
}
@@ -488,20 +501,27 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su
}
key, hasNext := isolateParticle(topic, d)
for _, partKey := range []string{key, "+", "#"} {
for _, partKey := range []string{key, "+"} {
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
if wild := particle.particles.get("#"); wild != nil && partKey != "#" && partKey != "+" {
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
}
if hasNext {
x.scanSubscribers(topic, d+1, particle, subs)
} else {
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
if wild := particle.particles.get("#"); wild != nil && partKey != "+" {
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
x.gatherSharedSubscriptions(wild, subs)
}
}
}
}
if particle := n.particles.get("#"); particle != nil {
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
}
return subs
}
@@ -619,6 +639,7 @@ type particle struct {
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle
}
// newParticle returns a pointer to a new instance of particle.

View File

@@ -319,7 +319,7 @@ func TestUnsubscribeShared(t *testing.T) {
require.True(t, exists)
require.Equal(t, byte(2), client.Qos)
require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1"))
require.True(t, index.Unsubscribe("$share/tmp/a/b/c", "cl1"))
_, exists = final.shared.Get("tmp", "cl1")
require.False(t, exists)
}
@@ -501,28 +501,40 @@ func TestScanSubscribers(t *testing.T) {
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 4, len(subs.Subscriptions))
require.Equal(t, 3, len(subs.Subscriptions))
require.Contains(t, subs.Subscriptions, "cl1")
require.Contains(t, subs.Subscriptions, "cl2")
require.Contains(t, subs.Subscriptions, "cl3")
require.Contains(t, subs.Subscriptions, "cl4")
require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos)
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"])
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
subs = index.scanSubscribers("d/e/f/g", 0, nil, new(Subscribers))
require.Equal(t, 1, len(subs.Subscriptions))
require.Contains(t, subs.Subscriptions, "cl4")
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
subs = index.scanSubscribers("", 0, nil, new(Subscribers))
require.Equal(t, 0, len(subs.Subscriptions))
}
func TestScanSubscribersTopicInheritanceBug(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 0, Filter: "a/b/c"})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/b"})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 1, len(subs.Subscriptions))
}
func TestScanSubscribersShared(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
@@ -531,8 +543,9 @@ func TestScanSubscribersShared(t *testing.T) {
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
index.Subscribe("cl5", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c/#"})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 3, len(subs.Shared))
require.Equal(t, 4, len(subs.Shared))
}
func TestSelectSharedSubscriber(t *testing.T) {

View File

@@ -32,7 +32,7 @@ type histogram struct {
valueCount int64 // number of values recorded for single value
}
// AddMeasurement records a value measurement observation to the histogram.
// addMeasurement records a value measurement observation to the histogram.
func (h *histogram) addMeasurement(value int64) {
// TODO: assert invariant
h.sum += value

View File

@@ -395,7 +395,7 @@ func New(family, title string) Trace {
}
func (tr *trace) Finish() {
elapsed := time.Now().Sub(tr.Start)
elapsed := time.Since(tr.Start)
tr.mu.Lock()
tr.Elapsed = elapsed
tr.mu.Unlock()

View File

@@ -1,30 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package unsafeheader contains header declarations for the Go runtime's
// slice and string implementations.
//
// This package allows x/sys to use types equivalent to
// reflect.SliceHeader and reflect.StringHeader without introducing
// a dependency on the (relatively heavy) "reflect" package.
package unsafeheader
import (
"unsafe"
)
// Slice is the runtime representation of a slice.
// It cannot be used safely or portably and its representation may change in a later release.
type Slice struct {
Data unsafe.Pointer
Len int
Cap int
}
// String is the runtime representation of a string.
// It cannot be used safely or portably and its representation may change in a later release.
type String struct {
Data unsafe.Pointer
Len int
}

31
vendor/golang.org/x/sys/unix/asm_bsd_ppc64.s generated vendored Normal file
View File

@@ -0,0 +1,31 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || freebsd || netbsd || openbsd) && gc
// +build darwin freebsd netbsd openbsd
// +build gc
#include "textflag.h"
//
// System call support for ppc64, BSD
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package unix

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gccgo && !aix
// +build gccgo,!aix
//go:build gccgo && !aix && !hurd
// +build gccgo,!aix,!hurd
package unix

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build gccgo
// +build !aix
//go:build gccgo && !aix && !hurd
// +build gccgo,!aix,!hurd
#include <errno.h>
#include <stdint.h>

View File

@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package unix

View File

@@ -4,9 +4,7 @@
package unix
import (
"unsafe"
)
import "unsafe"
// IoctlRetInt performs an ioctl operation specified by req on a device
// associated with opened file descriptor fd, and returns a non-negative
@@ -217,3 +215,19 @@ func IoctlKCMAttach(fd int, info KCMAttach) error {
func IoctlKCMUnattach(fd int, info KCMUnattach) error {
return ioctlPtr(fd, SIOCKCMUNATTACH, unsafe.Pointer(&info))
}
// IoctlLoopGetStatus64 gets the status of the loop device associated with the
// file descriptor fd using the LOOP_GET_STATUS64 operation.
func IoctlLoopGetStatus64(fd int) (*LoopInfo64, error) {
var value LoopInfo64
if err := ioctlPtr(fd, LOOP_GET_STATUS64, unsafe.Pointer(&value)); err != nil {
return nil, err
}
return &value, nil
}
// IoctlLoopSetStatus64 sets the status of the loop device associated with the
// file descriptor fd using the LOOP_SET_STATUS64 operation.
func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error {
return ioctlPtr(fd, LOOP_SET_STATUS64, unsafe.Pointer(value))
}

View File

@@ -73,12 +73,12 @@ aix_ppc64)
darwin_amd64)
mkerrors="$mkerrors -m64"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
mkasm="go run mkasm_darwin.go"
mkasm="go run mkasm.go"
;;
darwin_arm64)
mkerrors="$mkerrors -m64"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
mkasm="go run mkasm_darwin.go"
mkasm="go run mkasm.go"
;;
dragonfly_amd64)
mkerrors="$mkerrors -m64"
@@ -142,42 +142,60 @@ netbsd_arm64)
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_386)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m32"
mksyscall="go run mksyscall.go -l32 -openbsd"
mksyscall="go run mksyscall.go -l32 -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_amd64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
;;
openbsd_arm)
mkasm="go run mkasm.go"
mkerrors="$mkerrors"
mksyscall="go run mksyscall.go -l32 -openbsd -arm"
mksyscall="go run mksyscall.go -l32 -openbsd -arm -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_arm64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_mips64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_ppc64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
;;
openbsd_riscv64)
mkasm="go run mkasm.go"
mkerrors="$mkerrors -m64"
mksyscall="go run mksyscall.go -openbsd -libc"
mksysctl="go run mksysctl_openbsd.go"
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
# Let the type of C char be signed for making the bare syscall
# API consistent across platforms.
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
@@ -214,11 +232,6 @@ esac
if [ "$GOOSARCH" == "aix_ppc64" ]; then
# aix/ppc64 script generates files instead of writing to stdin.
echo "$mksyscall -tags $GOOS,$GOARCH $syscall_goos $GOOSARCH_in && gofmt -w zsyscall_$GOOSARCH.go && gofmt -w zsyscall_"$GOOSARCH"_gccgo.go && gofmt -w zsyscall_"$GOOSARCH"_gc.go " ;
elif [ "$GOOS" == "darwin" ]; then
# 1.12 and later, syscalls via libSystem
echo "$mksyscall -tags $GOOS,$GOARCH,go1.12 $syscall_goos $GOOSARCH_in |gofmt >zsyscall_$GOOSARCH.go";
# 1.13 and later, syscalls via libSystem (including syscallPtr)
echo "$mksyscall -tags $GOOS,$GOARCH,go1.13 syscall_darwin.1_13.go |gofmt >zsyscall_$GOOSARCH.1_13.go";
elif [ "$GOOS" == "illumos" ]; then
# illumos code generation requires a --illumos switch
echo "$mksyscall -illumos -tags illumos,$GOARCH syscall_illumos.go |gofmt > zsyscall_illumos_$GOARCH.go";
@@ -232,5 +245,5 @@ esac
if [ -n "$mksysctl" ]; then echo "$mksysctl |gofmt >$zsysctl"; fi
if [ -n "$mksysnum" ]; then echo "$mksysnum |gofmt >zsysnum_$GOOSARCH.go"; fi
if [ -n "$mktypes" ]; then echo "$mktypes types_$GOOS.go | go run mkpost.go > ztypes_$GOOSARCH.go"; fi
if [ -n "$mkasm" ]; then echo "$mkasm $GOARCH"; fi
if [ -n "$mkasm" ]; then echo "$mkasm $GOOS $GOARCH"; fi
) | $run

View File

@@ -642,7 +642,7 @@ errors=$(
signals=$(
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' |
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
sort
)
@@ -652,7 +652,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
sort >_error.grep
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' |
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
sort >_signal.grep
echo '// mkerrors.sh' "$@"

View File

@@ -52,6 +52,20 @@ func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) {
return msgs, nil
}
// ParseOneSocketControlMessage parses a single socket control message from b, returning the message header,
// message data (a slice of b), and the remainder of b after that single message.
// When there are no remaining messages, len(remainder) == 0.
func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) {
h, dbuf, err := socketControlMessageHeaderAndData(b)
if err != nil {
return Cmsghdr{}, nil, nil, err
}
if i := cmsgAlignOf(int(h.Len)); i < len(b) {
remainder = b[i:]
}
return *h, dbuf, remainder, nil
}
func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) {
h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) {

27
vendor/golang.org/x/sys/unix/str.go generated vendored
View File

@@ -1,27 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package unix
func itoa(val int) string { // do it here rather than with fmt to avoid dependency
if val < 0 {
return "-" + uitoa(uint(-val))
}
return uitoa(uint(val))
}
func uitoa(val uint) string {
var buf [32]byte // big enough for int64
i := len(buf) - 1
for val >= 10 {
buf[i] = byte(val%10 + '0')
i--
val /= 10
}
buf[i] = byte(val + '0')
return string(buf[i:])
}

View File

@@ -29,8 +29,6 @@ import (
"bytes"
"strings"
"unsafe"
"golang.org/x/sys/internal/unsafeheader"
)
// ByteSliceFromString returns a NUL-terminated slice of bytes
@@ -82,13 +80,7 @@ func BytePtrToString(p *byte) string {
ptr = unsafe.Pointer(uintptr(ptr) + 1)
}
var s []byte
h := (*unsafeheader.Slice)(unsafe.Pointer(&s))
h.Data = unsafe.Pointer(p)
h.Len = n
h.Cap = n
return string(s)
return string(unsafe.Slice(p, n))
}
// Single-word zero for use when we need a valid pointer to 0 bytes.

View File

@@ -253,7 +253,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
var empty bool
if len(oob) > 0 {
// send at least one normal byte
empty := emptyIovecs(iov)
empty = emptyIovecs(iov)
if empty {
var iova [1]Iovec
iova[0].Base = &dummy

View File

@@ -363,7 +363,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
var empty bool
if len(oob) > 0 {
// send at least one normal byte
empty := emptyIovecs(iov)
empty = emptyIovecs(iov)
if empty {
var iova [1]Iovec
iova[0].Base = &dummy

View File

@@ -1,32 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.12 && !go1.13
// +build darwin,go1.12,!go1.13
package unix
import (
"unsafe"
)
const _SYS_GETDIRENTRIES64 = 344
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// To implement this using libSystem we'd need syscall_syscallPtr for
// fdopendir. However, syscallPtr was only added in Go 1.13, so we fall
// back to raw syscalls for this func on Go 1.12.
var p unsafe.Pointer
if len(buf) > 0 {
p = unsafe.Pointer(&buf[0])
} else {
p = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(_SYS_GETDIRENTRIES64, uintptr(fd), uintptr(p), uintptr(len(buf)), uintptr(unsafe.Pointer(basep)), 0, 0)
n = int(r0)
if e1 != 0 {
return n, errnoErr(e1)
}
return n, nil
}

View File

@@ -1,108 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.13
// +build darwin,go1.13
package unix
import (
"unsafe"
"golang.org/x/sys/internal/unsafeheader"
)
//sys closedir(dir uintptr) (err error)
//sys readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno)
func fdopendir(fd int) (dir uintptr, err error) {
r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0)
dir = uintptr(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_fdopendir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib"
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// Simulate Getdirentries using fdopendir/readdir_r/closedir.
// We store the number of entries to skip in the seek
// offset of fd. See issue #31368.
// It's not the full required semantics, but should handle the case
// of calling Getdirentries or ReadDirent repeatedly.
// It won't handle assigning the results of lseek to *basep, or handle
// the directory being edited underfoot.
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
if err != nil {
return 0, err
}
// We need to duplicate the incoming file descriptor
// because the caller expects to retain control of it, but
// fdopendir expects to take control of its argument.
// Just Dup'ing the file descriptor is not enough, as the
// result shares underlying state. Use Openat to make a really
// new file descriptor referring to the same directory.
fd2, err := Openat(fd, ".", O_RDONLY, 0)
if err != nil {
return 0, err
}
d, err := fdopendir(fd2)
if err != nil {
Close(fd2)
return 0, err
}
defer closedir(d)
var cnt int64
for {
var entry Dirent
var entryp *Dirent
e := readdir_r(d, &entry, &entryp)
if e != 0 {
return n, errnoErr(e)
}
if entryp == nil {
break
}
if skip > 0 {
skip--
cnt++
continue
}
reclen := int(entry.Reclen)
if reclen > len(buf) {
// Not enough room. Return for now.
// The counter will let us know where we should start up again.
// Note: this strategy for suspending in the middle and
// restarting is O(n^2) in the length of the directory. Oh well.
break
}
// Copy entry into return buffer.
var s []byte
hdr := (*unsafeheader.Slice)(unsafe.Pointer(&s))
hdr.Data = unsafe.Pointer(&entry)
hdr.Cap = reclen
hdr.Len = reclen
copy(buf, s)
buf = buf[reclen:]
n += reclen
cnt++
}
// Set the seek offset of the input fd to record
// how many files we've already returned.
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
if err != nil {
return n, err
}
return n, nil
}

View File

@@ -19,6 +19,96 @@ import (
"unsafe"
)
//sys closedir(dir uintptr) (err error)
//sys readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno)
func fdopendir(fd int) (dir uintptr, err error) {
r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0)
dir = uintptr(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_fdopendir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib"
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// Simulate Getdirentries using fdopendir/readdir_r/closedir.
// We store the number of entries to skip in the seek
// offset of fd. See issue #31368.
// It's not the full required semantics, but should handle the case
// of calling Getdirentries or ReadDirent repeatedly.
// It won't handle assigning the results of lseek to *basep, or handle
// the directory being edited underfoot.
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
if err != nil {
return 0, err
}
// We need to duplicate the incoming file descriptor
// because the caller expects to retain control of it, but
// fdopendir expects to take control of its argument.
// Just Dup'ing the file descriptor is not enough, as the
// result shares underlying state. Use Openat to make a really
// new file descriptor referring to the same directory.
fd2, err := Openat(fd, ".", O_RDONLY, 0)
if err != nil {
return 0, err
}
d, err := fdopendir(fd2)
if err != nil {
Close(fd2)
return 0, err
}
defer closedir(d)
var cnt int64
for {
var entry Dirent
var entryp *Dirent
e := readdir_r(d, &entry, &entryp)
if e != 0 {
return n, errnoErr(e)
}
if entryp == nil {
break
}
if skip > 0 {
skip--
cnt++
continue
}
reclen := int(entry.Reclen)
if reclen > len(buf) {
// Not enough room. Return for now.
// The counter will let us know where we should start up again.
// Note: this strategy for suspending in the middle and
// restarting is O(n^2) in the length of the directory. Oh well.
break
}
// Copy entry into return buffer.
s := unsafe.Slice((*byte)(unsafe.Pointer(&entry)), reclen)
copy(buf, s)
buf = buf[reclen:]
n += reclen
cnt++
}
// Set the seek offset of the input fd to record
// how many files we've already returned.
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
if err != nil {
return n, err
}
return n, nil
}
// SockaddrDatalink implements the Sockaddr interface for AF_LINK type sockets.
type SockaddrDatalink struct {
Len uint8
@@ -140,6 +230,7 @@ func direntNamlen(buf []byte) (uint64, bool) {
func PtraceAttach(pid int) (err error) { return ptrace(PT_ATTACH, pid, 0, 0) }
func PtraceDetach(pid int) (err error) { return ptrace(PT_DETACH, pid, 0, 0) }
func PtraceDenyAttach() (err error) { return ptrace(PT_DENY_ATTACH, 0, 0, 0) }
//sysnb pipe(p *[2]int32) (err error)

View File

@@ -255,6 +255,7 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
//sys Chmod(path string, mode uint32) (err error)
//sys Chown(path string, uid int, gid int) (err error)
//sys Chroot(path string) (err error)
//sys ClockGettime(clockid int32, time *Timespec) (err error)
//sys Close(fd int) (err error)
//sys Dup(fd int) (nfd int, err error)
//sys Dup2(from int, to int) (err error)

View File

@@ -319,6 +319,7 @@ func PtraceSingleStep(pid int) (err error) {
//sys Chmod(path string, mode uint32) (err error)
//sys Chown(path string, uid int, gid int) (err error)
//sys Chroot(path string) (err error)
//sys ClockGettime(clockid int32, time *Timespec) (err error)
//sys Close(fd int) (err error)
//sys Dup(fd int) (nfd int, err error)
//sys Dup2(from int, to int) (err error)

View File

@@ -60,8 +60,13 @@ func PtraceGetFsBase(pid int, fsbase *int64) (err error) {
return ptrace(PT_GETFSBASE, pid, uintptr(unsafe.Pointer(fsbase)), 0)
}
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint32(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint32(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

View File

@@ -60,8 +60,13 @@ func PtraceGetFsBase(pid int, fsbase *int64) (err error) {
return ptrace(PT_GETFSBASE, pid, uintptr(unsafe.Pointer(fsbase)), 0)
}
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint64(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

View File

@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint32(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint32(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

View File

@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint64(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

View File

@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
ioDesc := PtraceIoDesc{
Op: int32(req),
Offs: offs,
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
Len: uint64(countin),
}
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
return int(ioDesc.Len), err
}

22
vendor/golang.org/x/sys/unix/syscall_hurd.go generated vendored Normal file
View File

@@ -0,0 +1,22 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build hurd
// +build hurd
package unix
/*
#include <stdint.h>
int ioctl(int, unsigned long int, uintptr_t);
*/
import "C"
func ioctl(fd int, req uint, arg uintptr) (err error) {
r0, er := C.ioctl(C.int(fd), C.ulong(req), C.uintptr_t(arg))
if r0 == -1 && er != nil {
err = er
}
return
}

29
vendor/golang.org/x/sys/unix/syscall_hurd_386.go generated vendored Normal file
View File

@@ -0,0 +1,29 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build 386 && hurd
// +build 386,hurd
package unix
const (
TIOCGETA = 0x62251713
)
type Winsize struct {
Row uint16
Col uint16
Xpixel uint16
Ypixel uint16
}
type Termios struct {
Iflag uint32
Oflag uint32
Cflag uint32
Lflag uint32
Cc [20]uint8
Ispeed int32
Ospeed int32
}

View File

@@ -10,8 +10,6 @@
package unix
import (
"fmt"
"runtime"
"unsafe"
)
@@ -79,107 +77,3 @@ func Accept4(fd int, flags int) (nfd int, sa Sockaddr, err error) {
}
return
}
//sys putmsg(fd int, clptr *strbuf, dataptr *strbuf, flags int) (err error)
func Putmsg(fd int, cl []byte, data []byte, flags int) (err error) {
var clp, datap *strbuf
if len(cl) > 0 {
clp = &strbuf{
Len: int32(len(cl)),
Buf: (*int8)(unsafe.Pointer(&cl[0])),
}
}
if len(data) > 0 {
datap = &strbuf{
Len: int32(len(data)),
Buf: (*int8)(unsafe.Pointer(&data[0])),
}
}
return putmsg(fd, clp, datap, flags)
}
//sys getmsg(fd int, clptr *strbuf, dataptr *strbuf, flags *int) (err error)
func Getmsg(fd int, cl []byte, data []byte) (retCl []byte, retData []byte, flags int, err error) {
var clp, datap *strbuf
if len(cl) > 0 {
clp = &strbuf{
Maxlen: int32(len(cl)),
Buf: (*int8)(unsafe.Pointer(&cl[0])),
}
}
if len(data) > 0 {
datap = &strbuf{
Maxlen: int32(len(data)),
Buf: (*int8)(unsafe.Pointer(&data[0])),
}
}
if err = getmsg(fd, clp, datap, &flags); err != nil {
return nil, nil, 0, err
}
if len(cl) > 0 {
retCl = cl[:clp.Len]
}
if len(data) > 0 {
retData = data[:datap.Len]
}
return retCl, retData, flags, nil
}
func IoctlSetIntRetInt(fd int, req uint, arg int) (int, error) {
return ioctlRet(fd, req, uintptr(arg))
}
func IoctlSetString(fd int, req uint, val string) error {
bs := make([]byte, len(val)+1)
copy(bs[:len(bs)-1], val)
err := ioctl(fd, req, uintptr(unsafe.Pointer(&bs[0])))
runtime.KeepAlive(&bs[0])
return err
}
// Lifreq Helpers
func (l *Lifreq) SetName(name string) error {
if len(name) >= len(l.Name) {
return fmt.Errorf("name cannot be more than %d characters", len(l.Name)-1)
}
for i := range name {
l.Name[i] = int8(name[i])
}
return nil
}
func (l *Lifreq) SetLifruInt(d int) {
*(*int)(unsafe.Pointer(&l.Lifru[0])) = d
}
func (l *Lifreq) GetLifruInt() int {
return *(*int)(unsafe.Pointer(&l.Lifru[0]))
}
func (l *Lifreq) SetLifruUint(d uint) {
*(*uint)(unsafe.Pointer(&l.Lifru[0])) = d
}
func (l *Lifreq) GetLifruUint() uint {
return *(*uint)(unsafe.Pointer(&l.Lifru[0]))
}
func IoctlLifreq(fd int, req uint, l *Lifreq) error {
return ioctl(fd, req, uintptr(unsafe.Pointer(l)))
}
// Strioctl Helpers
func (s *Strioctl) SetInt(i int) {
s.Len = int32(unsafe.Sizeof(i))
s.Dp = (*int8)(unsafe.Pointer(&i))
}
func IoctlSetStrioctlRetInt(fd int, req uint, s *Strioctl) (int, error) {
return ioctlRet(fd, req, uintptr(unsafe.Pointer(s)))
}

View File

@@ -13,6 +13,7 @@ package unix
import (
"encoding/binary"
"strconv"
"syscall"
"time"
"unsafe"
@@ -233,7 +234,7 @@ func Futimesat(dirfd int, path string, tv []Timeval) error {
func Futimes(fd int, tv []Timeval) (err error) {
// Believe it or not, this is the best we can do on Linux
// (and is what glibc does).
return Utimes("/proc/self/fd/"+itoa(fd), tv)
return Utimes("/proc/self/fd/"+strconv.Itoa(fd), tv)
}
const ImplementsGetwd = true
@@ -1541,7 +1542,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
var dummy byte
var empty bool
if len(oob) > 0 {
empty := emptyIovecs(iov)
empty = emptyIovecs(iov)
if empty {
var sockType int
sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
@@ -1553,6 +1554,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
var iova [1]Iovec
iova[0].Base = &dummy
iova[0].SetLen(1)
iov = iova[:]
}
}
msg.Control = &oob[0]
@@ -1798,6 +1800,7 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
//sysnb Capset(hdr *CapUserHeader, data *CapUserData) (err error)
//sys Chdir(path string) (err error)
//sys Chroot(path string) (err error)
//sys ClockAdjtime(clockid int32, buf *Timex) (state int, err error)
//sys ClockGetres(clockid int32, res *Timespec) (err error)
//sys ClockGettime(clockid int32, time *Timespec) (err error)
//sys ClockNanosleep(clockid int32, flags int, request *Timespec, remain *Timespec) (err error)
@@ -1891,17 +1894,28 @@ func PrctlRetInt(option int, arg2 uintptr, arg3 uintptr, arg4 uintptr, arg5 uint
return int(ret), nil
}
// issue 1435.
// On linux Setuid and Setgid only affects the current thread, not the process.
// This does not match what most callers expect so we must return an error
// here rather than letting the caller think that the call succeeded.
func Setuid(uid int) (err error) {
return EOPNOTSUPP
return syscall.Setuid(uid)
}
func Setgid(uid int) (err error) {
return EOPNOTSUPP
func Setgid(gid int) (err error) {
return syscall.Setgid(gid)
}
func Setreuid(ruid, euid int) (err error) {
return syscall.Setreuid(ruid, euid)
}
func Setregid(rgid, egid int) (err error) {
return syscall.Setregid(rgid, egid)
}
func Setresuid(ruid, euid, suid int) (err error) {
return syscall.Setresuid(ruid, euid, suid)
}
func Setresgid(rgid, egid, sgid int) (err error) {
return syscall.Setresgid(rgid, egid, sgid)
}
// SetfsgidRetGid sets fsgid for current thread and returns previous fsgid set.
@@ -1960,36 +1974,46 @@ func Signalfd(fd int, sigmask *Sigset_t, flags int) (newfd int, err error) {
//sys preadv2(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr, flags int) (n int, err error) = SYS_PREADV2
//sys pwritev2(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr, flags int) (n int, err error) = SYS_PWRITEV2
func bytes2iovec(bs [][]byte) []Iovec {
iovecs := make([]Iovec, len(bs))
for i, b := range bs {
iovecs[i].SetLen(len(b))
// minIovec is the size of the small initial allocation used by
// Readv, Writev, etc.
//
// This small allocation gets stack allocated, which lets the
// common use case of len(iovs) <= minIovs avoid more expensive
// heap allocations.
const minIovec = 8
// appendBytes converts bs to Iovecs and appends them to vecs.
func appendBytes(vecs []Iovec, bs [][]byte) []Iovec {
for _, b := range bs {
var v Iovec
v.SetLen(len(b))
if len(b) > 0 {
iovecs[i].Base = &b[0]
v.Base = &b[0]
} else {
iovecs[i].Base = (*byte)(unsafe.Pointer(&_zero))
v.Base = (*byte)(unsafe.Pointer(&_zero))
}
vecs = append(vecs, v)
}
return iovecs
return vecs
}
// offs2lohi splits offs into its lower and upper unsigned long. On 64-bit
// systems, hi will always be 0. On 32-bit systems, offs will be split in half.
// preadv/pwritev chose this calling convention so they don't need to add a
// padding-register for alignment on ARM.
// offs2lohi splits offs into its low and high order bits.
func offs2lohi(offs int64) (lo, hi uintptr) {
return uintptr(offs), uintptr(uint64(offs) >> SizeofLong)
const longBits = SizeofLong * 8
return uintptr(offs), uintptr(uint64(offs) >> (longBits - 1) >> 1) // two shifts to avoid false positive in vet
}
func Readv(fd int, iovs [][]byte) (n int, err error) {
iovecs := bytes2iovec(iovs)
iovecs := make([]Iovec, 0, minIovec)
iovecs = appendBytes(iovecs, iovs)
n, err = readv(fd, iovecs)
readvRacedetect(iovecs, n, err)
return n, err
}
func Preadv(fd int, iovs [][]byte, offset int64) (n int, err error) {
iovecs := bytes2iovec(iovs)
iovecs := make([]Iovec, 0, minIovec)
iovecs = appendBytes(iovecs, iovs)
lo, hi := offs2lohi(offset)
n, err = preadv(fd, iovecs, lo, hi)
readvRacedetect(iovecs, n, err)
@@ -1997,7 +2021,8 @@ func Preadv(fd int, iovs [][]byte, offset int64) (n int, err error) {
}
func Preadv2(fd int, iovs [][]byte, offset int64, flags int) (n int, err error) {
iovecs := bytes2iovec(iovs)
iovecs := make([]Iovec, 0, minIovec)
iovecs = appendBytes(iovecs, iovs)
lo, hi := offs2lohi(offset)
n, err = preadv2(fd, iovecs, lo, hi, flags)
readvRacedetect(iovecs, n, err)
@@ -2024,7 +2049,8 @@ func readvRacedetect(iovecs []Iovec, n int, err error) {
}
func Writev(fd int, iovs [][]byte) (n int, err error) {
iovecs := bytes2iovec(iovs)
iovecs := make([]Iovec, 0, minIovec)
iovecs = appendBytes(iovecs, iovs)
if raceenabled {
raceReleaseMerge(unsafe.Pointer(&ioSync))
}
@@ -2034,7 +2060,8 @@ func Writev(fd int, iovs [][]byte) (n int, err error) {
}
func Pwritev(fd int, iovs [][]byte, offset int64) (n int, err error) {
iovecs := bytes2iovec(iovs)
iovecs := make([]Iovec, 0, minIovec)
iovecs = appendBytes(iovecs, iovs)
if raceenabled {
raceReleaseMerge(unsafe.Pointer(&ioSync))
}
@@ -2045,7 +2072,8 @@ func Pwritev(fd int, iovs [][]byte, offset int64) (n int, err error) {
}
func Pwritev2(fd int, iovs [][]byte, offset int64, flags int) (n int, err error) {
iovecs := bytes2iovec(iovs)
iovecs := make([]Iovec, 0, minIovec)
iovecs = appendBytes(iovecs, iovs)
if raceenabled {
raceReleaseMerge(unsafe.Pointer(&ioSync))
}
@@ -2240,7 +2268,7 @@ func (fh *FileHandle) Bytes() []byte {
if n == 0 {
return nil
}
return (*[1 << 30]byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&fh.fileHandle.Type)) + 4))[:n:n]
return unsafe.Slice((*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&fh.fileHandle.Type))+4)), n)
}
// NameToHandleAt wraps the name_to_handle_at system call; it obtains
@@ -2356,6 +2384,16 @@ func Setitimer(which ItimerWhich, it Itimerval) (Itimerval, error) {
return prev, nil
}
//sysnb rtSigprocmask(how int, set *Sigset_t, oldset *Sigset_t, sigsetsize uintptr) (err error) = SYS_RT_SIGPROCMASK
func PthreadSigmask(how int, set, oldset *Sigset_t) error {
if oldset != nil {
// Explicitly clear in case Sigset_t is larger than _C__NSIG.
*oldset = Sigset_t{}
}
return rtSigprocmask(how, set, oldset, _C__NSIG/8)
}
/*
* Unimplemented
*/
@@ -2414,7 +2452,6 @@ func Setitimer(which ItimerWhich, it Itimerval) (Itimerval, error) {
// RestartSyscall
// RtSigaction
// RtSigpending
// RtSigprocmask
// RtSigqueueinfo
// RtSigreturn
// RtSigsuspend

View File

@@ -41,10 +41,6 @@ func setTimeval(sec, usec int64) Timeval {
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) = SYS_SENDFILE64
//sys setfsgid(gid int) (prev int, err error) = SYS_SETFSGID32
//sys setfsuid(uid int) (prev int, err error) = SYS_SETFSUID32
//sysnb Setregid(rgid int, egid int) (err error) = SYS_SETREGID32
//sysnb Setresgid(rgid int, egid int, sgid int) (err error) = SYS_SETRESGID32
//sysnb Setresuid(ruid int, euid int, suid int) (err error) = SYS_SETRESUID32
//sysnb Setreuid(ruid int, euid int) (err error) = SYS_SETREUID32
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error)
//sys Stat(path string, stat *Stat_t) (err error) = SYS_STAT64
//sys SyncFileRange(fd int, off int64, n int64, flags int) (err error)

View File

@@ -46,11 +46,7 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)

View File

@@ -62,10 +62,6 @@ func Seek(fd int, offset int64, whence int) (newoffset int64, err error) {
//sys Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error) = SYS__NEWSELECT
//sys setfsgid(gid int) (prev int, err error) = SYS_SETFSGID32
//sys setfsuid(uid int) (prev int, err error) = SYS_SETFSUID32
//sysnb Setregid(rgid int, egid int) (err error) = SYS_SETREGID32
//sysnb Setresgid(rgid int, egid int, sgid int) (err error) = SYS_SETRESGID32
//sysnb Setresuid(ruid int, euid int, suid int) (err error) = SYS_SETRESUID32
//sysnb Setreuid(ruid int, euid int) (err error) = SYS_SETREUID32
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error)
//sys Stat(path string, stat *Stat_t) (err error) = SYS_STAT64

View File

@@ -39,11 +39,7 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb setrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)

View File

@@ -34,10 +34,6 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)

View File

@@ -37,11 +37,7 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)
//sys Statfs(path string, buf *Statfs_t) (err error)

View File

@@ -32,10 +32,6 @@ func Syscall9(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr,
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) = SYS_SENDFILE64
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error)
//sys SyncFileRange(fd int, off int64, n int64, flags int) (err error)

View File

@@ -34,10 +34,6 @@ import (
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) = SYS_SENDFILE64
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error)
//sys Stat(path string, stat *Stat_t) (err error) = SYS_STAT64

View File

@@ -34,11 +34,7 @@ package unix
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)
//sys Stat(path string, stat *Stat_t) (err error)

View File

@@ -38,11 +38,7 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)

View File

@@ -34,11 +34,7 @@ import (
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)
//sys Stat(path string, stat *Stat_t) (err error)
//sys Statfs(path string, buf *Statfs_t) (err error)

View File

@@ -31,11 +31,7 @@ package unix
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys setfsgid(gid int) (prev int, err error)
//sys setfsuid(uid int) (prev int, err error)
//sysnb Setregid(rgid int, egid int) (err error)
//sysnb Setresgid(rgid int, egid int, sgid int) (err error)
//sysnb Setresuid(ruid int, euid int, suid int) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Setreuid(ruid int, euid int) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)
//sys Stat(path string, stat *Stat_t) (err error)

View File

@@ -110,6 +110,20 @@ func direntNamlen(buf []byte) (uint64, bool) {
return readInt(buf, unsafe.Offsetof(Dirent{}.Namlen), unsafe.Sizeof(Dirent{}.Namlen))
}
func SysctlUvmexp(name string) (*Uvmexp, error) {
mib, err := sysctlmib(name)
if err != nil {
return nil, err
}
n := uintptr(SizeofUvmexp)
var u Uvmexp
if err := sysctl(mib, (*byte)(unsafe.Pointer(&u)), &n, nil, 0); err != nil {
return nil, err
}
return &u, nil
}
func Pipe(p []int) (err error) {
return Pipe2(p, 0)
}
@@ -245,6 +259,7 @@ func Statvfs(path string, buf *Statvfs_t) (err error) {
//sys Chmod(path string, mode uint32) (err error)
//sys Chown(path string, uid int, gid int) (err error)
//sys Chroot(path string) (err error)
//sys ClockGettime(clockid int32, time *Timespec) (err error)
//sys Close(fd int) (err error)
//sys Dup(fd int) (nfd int, err error)
//sys Dup2(from int, to int) (err error)

View File

@@ -220,6 +220,7 @@ func Uname(uname *Utsname) error {
//sys Chmod(path string, mode uint32) (err error)
//sys Chown(path string, uid int, gid int) (err error)
//sys Chroot(path string) (err error)
//sys ClockGettime(clockid int32, time *Timespec) (err error)
//sys Close(fd int) (err error)
//sys Dup(fd int) (nfd int, err error)
//sys Dup2(from int, to int) (err error)

27
vendor/golang.org/x/sys/unix/syscall_openbsd_libc.go generated vendored Normal file
View File

@@ -0,0 +1,27 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build openbsd
// +build openbsd
package unix
import _ "unsafe"
// Implemented in the runtime package (runtime/sys_openbsd3.go)
func syscall_syscall(fn, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
func syscall_syscall6(fn, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err Errno)
func syscall_syscall10(fn, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 uintptr) (r1, r2 uintptr, err Errno)
func syscall_rawSyscall(fn, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno)
func syscall_rawSyscall6(fn, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err Errno)
//go:linkname syscall_syscall syscall.syscall
//go:linkname syscall_syscall6 syscall.syscall6
//go:linkname syscall_syscall10 syscall.syscall10
//go:linkname syscall_rawSyscall syscall.rawSyscall
//go:linkname syscall_rawSyscall6 syscall.rawSyscall6
func syscall_syscall9(fn, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err Errno) {
return syscall_syscall10(fn, a1, a2, a3, a4, a5, a6, a7, a8, a9, 0)
}

42
vendor/golang.org/x/sys/unix/syscall_openbsd_ppc64.go generated vendored Normal file
View File

@@ -0,0 +1,42 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build ppc64 && openbsd
// +build ppc64,openbsd
package unix
func setTimespec(sec, nsec int64) Timespec {
return Timespec{Sec: sec, Nsec: nsec}
}
func setTimeval(sec, usec int64) Timeval {
return Timeval{Sec: sec, Usec: usec}
}
func SetKevent(k *Kevent_t, fd, mode, flags int) {
k.Ident = uint64(fd)
k.Filter = int16(mode)
k.Flags = uint16(flags)
}
func (iov *Iovec) SetLen(length int) {
iov.Len = uint64(length)
}
func (msghdr *Msghdr) SetControllen(length int) {
msghdr.Controllen = uint32(length)
}
func (msghdr *Msghdr) SetIovlen(length int) {
msghdr.Iovlen = uint32(length)
}
func (cmsg *Cmsghdr) SetLen(length int) {
cmsg.Len = uint32(length)
}
// SYS___SYSCTL is used by syscall_bsd.go for all BSDs, but in modern versions
// of openbsd/ppc64 the syscall is called sysctl instead of __sysctl.
const SYS___SYSCTL = SYS_SYSCTL

View File

@@ -0,0 +1,42 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build riscv64 && openbsd
// +build riscv64,openbsd
package unix
func setTimespec(sec, nsec int64) Timespec {
return Timespec{Sec: sec, Nsec: nsec}
}
func setTimeval(sec, usec int64) Timeval {
return Timeval{Sec: sec, Usec: usec}
}
func SetKevent(k *Kevent_t, fd, mode, flags int) {
k.Ident = uint64(fd)
k.Filter = int16(mode)
k.Flags = uint16(flags)
}
func (iov *Iovec) SetLen(length int) {
iov.Len = uint64(length)
}
func (msghdr *Msghdr) SetControllen(length int) {
msghdr.Controllen = uint32(length)
}
func (msghdr *Msghdr) SetIovlen(length int) {
msghdr.Iovlen = uint32(length)
}
func (cmsg *Cmsghdr) SetLen(length int) {
cmsg.Len = uint32(length)
}
// SYS___SYSCTL is used by syscall_bsd.go for all BSDs, but in modern versions
// of openbsd/riscv64 the syscall is called sysctl instead of __sysctl.
const SYS___SYSCTL = SYS_SYSCTL

View File

@@ -590,6 +590,7 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
//sys Chmod(path string, mode uint32) (err error)
//sys Chown(path string, uid int, gid int) (err error)
//sys Chroot(path string) (err error)
//sys ClockGettime(clockid int32, time *Timespec) (err error)
//sys Close(fd int) (err error)
//sys Creat(path string, mode uint32) (fd int, err error)
//sys Dup(fd int) (nfd int, err error)
@@ -750,8 +751,8 @@ type EventPort struct {
// we should handle things gracefully. To do so, we need to keep an extra
// reference to the cookie around until the event is processed
// thus the otherwise seemingly extraneous "cookies" map
// The key of this map is a pointer to the corresponding &fCookie.cookie
cookies map[*interface{}]*fileObjCookie
// The key of this map is a pointer to the corresponding fCookie
cookies map[*fileObjCookie]struct{}
}
// PortEvent is an abstraction of the port_event C struct.
@@ -778,7 +779,7 @@ func NewEventPort() (*EventPort, error) {
port: port,
fds: make(map[uintptr]*fileObjCookie),
paths: make(map[string]*fileObjCookie),
cookies: make(map[*interface{}]*fileObjCookie),
cookies: make(map[*fileObjCookie]struct{}),
}
return e, nil
}
@@ -799,6 +800,7 @@ func (e *EventPort) Close() error {
}
e.fds = nil
e.paths = nil
e.cookies = nil
return nil
}
@@ -826,17 +828,16 @@ func (e *EventPort) AssociatePath(path string, stat os.FileInfo, events int, coo
if _, found := e.paths[path]; found {
return fmt.Errorf("%v is already associated with this Event Port", path)
}
fobj, err := createFileObj(path, stat)
fCookie, err := createFileObjCookie(path, stat, cookie)
if err != nil {
return err
}
fCookie := &fileObjCookie{fobj, cookie}
_, err = port_associate(e.port, PORT_SOURCE_FILE, uintptr(unsafe.Pointer(fobj)), events, (*byte)(unsafe.Pointer(&fCookie.cookie)))
_, err = port_associate(e.port, PORT_SOURCE_FILE, uintptr(unsafe.Pointer(fCookie.fobj)), events, (*byte)(unsafe.Pointer(fCookie)))
if err != nil {
return err
}
e.paths[path] = fCookie
e.cookies[&fCookie.cookie] = fCookie
e.cookies[fCookie] = struct{}{}
return nil
}
@@ -858,7 +859,7 @@ func (e *EventPort) DissociatePath(path string) error {
if err == nil {
// dissociate was successful, safe to delete the cookie
fCookie := e.paths[path]
delete(e.cookies, &fCookie.cookie)
delete(e.cookies, fCookie)
}
delete(e.paths, path)
return err
@@ -871,13 +872,16 @@ func (e *EventPort) AssociateFd(fd uintptr, events int, cookie interface{}) erro
if _, found := e.fds[fd]; found {
return fmt.Errorf("%v is already associated with this Event Port", fd)
}
fCookie := &fileObjCookie{nil, cookie}
_, err := port_associate(e.port, PORT_SOURCE_FD, fd, events, (*byte)(unsafe.Pointer(&fCookie.cookie)))
fCookie, err := createFileObjCookie("", nil, cookie)
if err != nil {
return err
}
_, err = port_associate(e.port, PORT_SOURCE_FD, fd, events, (*byte)(unsafe.Pointer(fCookie)))
if err != nil {
return err
}
e.fds[fd] = fCookie
e.cookies[&fCookie.cookie] = fCookie
e.cookies[fCookie] = struct{}{}
return nil
}
@@ -896,27 +900,31 @@ func (e *EventPort) DissociateFd(fd uintptr) error {
if err == nil {
// dissociate was successful, safe to delete the cookie
fCookie := e.fds[fd]
delete(e.cookies, &fCookie.cookie)
delete(e.cookies, fCookie)
}
delete(e.fds, fd)
return err
}
func createFileObj(name string, stat os.FileInfo) (*fileObj, error) {
fobj := new(fileObj)
bs, err := ByteSliceFromString(name)
if err != nil {
return nil, err
func createFileObjCookie(name string, stat os.FileInfo, cookie interface{}) (*fileObjCookie, error) {
fCookie := new(fileObjCookie)
fCookie.cookie = cookie
if name != "" && stat != nil {
fCookie.fobj = new(fileObj)
bs, err := ByteSliceFromString(name)
if err != nil {
return nil, err
}
fCookie.fobj.Name = (*int8)(unsafe.Pointer(&bs[0]))
s := stat.Sys().(*syscall.Stat_t)
fCookie.fobj.Atim.Sec = s.Atim.Sec
fCookie.fobj.Atim.Nsec = s.Atim.Nsec
fCookie.fobj.Mtim.Sec = s.Mtim.Sec
fCookie.fobj.Mtim.Nsec = s.Mtim.Nsec
fCookie.fobj.Ctim.Sec = s.Ctim.Sec
fCookie.fobj.Ctim.Nsec = s.Ctim.Nsec
}
fobj.Name = (*int8)(unsafe.Pointer(&bs[0]))
s := stat.Sys().(*syscall.Stat_t)
fobj.Atim.Sec = s.Atim.Sec
fobj.Atim.Nsec = s.Atim.Nsec
fobj.Mtim.Sec = s.Mtim.Sec
fobj.Mtim.Nsec = s.Mtim.Nsec
fobj.Ctim.Sec = s.Ctim.Sec
fobj.Ctim.Nsec = s.Ctim.Nsec
return fobj, nil
return fCookie, nil
}
// GetOne wraps port_get(3c) and returns a single PortEvent.
@@ -929,44 +937,50 @@ func (e *EventPort) GetOne(t *Timespec) (*PortEvent, error) {
p := new(PortEvent)
e.mu.Lock()
defer e.mu.Unlock()
e.peIntToExt(pe, p)
err = e.peIntToExt(pe, p)
if err != nil {
return nil, err
}
return p, nil
}
// peIntToExt converts a cgo portEvent struct into the friendlier PortEvent
// NOTE: Always call this function while holding the e.mu mutex
func (e *EventPort) peIntToExt(peInt *portEvent, peExt *PortEvent) {
func (e *EventPort) peIntToExt(peInt *portEvent, peExt *PortEvent) error {
if e.cookies == nil {
return fmt.Errorf("this EventPort is already closed")
}
peExt.Events = peInt.Events
peExt.Source = peInt.Source
cookie := (*interface{})(unsafe.Pointer(peInt.User))
peExt.Cookie = *cookie
fCookie := (*fileObjCookie)(unsafe.Pointer(peInt.User))
_, found := e.cookies[fCookie]
if !found {
panic("unexpected event port address; may be due to kernel bug; see https://go.dev/issue/54254")
}
peExt.Cookie = fCookie.cookie
delete(e.cookies, fCookie)
switch peInt.Source {
case PORT_SOURCE_FD:
delete(e.cookies, cookie)
peExt.Fd = uintptr(peInt.Object)
// Only remove the fds entry if it exists and this cookie matches
if fobj, ok := e.fds[peExt.Fd]; ok {
if &fobj.cookie == cookie {
if fobj == fCookie {
delete(e.fds, peExt.Fd)
}
}
case PORT_SOURCE_FILE:
if fCookie, ok := e.cookies[cookie]; ok && uintptr(unsafe.Pointer(fCookie.fobj)) == uintptr(peInt.Object) {
// Use our stashed reference rather than using unsafe on what we got back
// the unsafe version would be (*fileObj)(unsafe.Pointer(uintptr(peInt.Object)))
peExt.fobj = fCookie.fobj
} else {
panic("mismanaged memory")
}
delete(e.cookies, cookie)
peExt.fobj = fCookie.fobj
peExt.Path = BytePtrToString((*byte)(unsafe.Pointer(peExt.fobj.Name)))
// Only remove the paths entry if it exists and this cookie matches
if fobj, ok := e.paths[peExt.Path]; ok {
if &fobj.cookie == cookie {
if fobj == fCookie {
delete(e.paths, peExt.Path)
}
}
}
return nil
}
// Pending wraps port_getn(3c) and returns how many events are pending.
@@ -990,7 +1004,7 @@ func (e *EventPort) Get(s []PortEvent, min int, timeout *Timespec) (int, error)
got := uint32(min)
max := uint32(len(s))
var err error
ps := make([]portEvent, max, max)
ps := make([]portEvent, max)
_, err = port_getn(e.port, &ps[0], max, &got, timeout)
// got will be trustworthy with ETIME, but not any other error.
if err != nil && err != ETIME {
@@ -998,8 +1012,122 @@ func (e *EventPort) Get(s []PortEvent, min int, timeout *Timespec) (int, error)
}
e.mu.Lock()
defer e.mu.Unlock()
valid := 0
for i := 0; i < int(got); i++ {
e.peIntToExt(&ps[i], &s[i])
err2 := e.peIntToExt(&ps[i], &s[i])
if err2 != nil {
if valid == 0 && err == nil {
// If err2 is the only error and there are no valid events
// to return, return it to the caller.
err = err2
}
break
}
valid = i + 1
}
return int(got), err
return valid, err
}
//sys putmsg(fd int, clptr *strbuf, dataptr *strbuf, flags int) (err error)
func Putmsg(fd int, cl []byte, data []byte, flags int) (err error) {
var clp, datap *strbuf
if len(cl) > 0 {
clp = &strbuf{
Len: int32(len(cl)),
Buf: (*int8)(unsafe.Pointer(&cl[0])),
}
}
if len(data) > 0 {
datap = &strbuf{
Len: int32(len(data)),
Buf: (*int8)(unsafe.Pointer(&data[0])),
}
}
return putmsg(fd, clp, datap, flags)
}
//sys getmsg(fd int, clptr *strbuf, dataptr *strbuf, flags *int) (err error)
func Getmsg(fd int, cl []byte, data []byte) (retCl []byte, retData []byte, flags int, err error) {
var clp, datap *strbuf
if len(cl) > 0 {
clp = &strbuf{
Maxlen: int32(len(cl)),
Buf: (*int8)(unsafe.Pointer(&cl[0])),
}
}
if len(data) > 0 {
datap = &strbuf{
Maxlen: int32(len(data)),
Buf: (*int8)(unsafe.Pointer(&data[0])),
}
}
if err = getmsg(fd, clp, datap, &flags); err != nil {
return nil, nil, 0, err
}
if len(cl) > 0 {
retCl = cl[:clp.Len]
}
if len(data) > 0 {
retData = data[:datap.Len]
}
return retCl, retData, flags, nil
}
func IoctlSetIntRetInt(fd int, req uint, arg int) (int, error) {
return ioctlRet(fd, req, uintptr(arg))
}
func IoctlSetString(fd int, req uint, val string) error {
bs := make([]byte, len(val)+1)
copy(bs[:len(bs)-1], val)
err := ioctl(fd, req, uintptr(unsafe.Pointer(&bs[0])))
runtime.KeepAlive(&bs[0])
return err
}
// Lifreq Helpers
func (l *Lifreq) SetName(name string) error {
if len(name) >= len(l.Name) {
return fmt.Errorf("name cannot be more than %d characters", len(l.Name)-1)
}
for i := range name {
l.Name[i] = int8(name[i])
}
return nil
}
func (l *Lifreq) SetLifruInt(d int) {
*(*int)(unsafe.Pointer(&l.Lifru[0])) = d
}
func (l *Lifreq) GetLifruInt() int {
return *(*int)(unsafe.Pointer(&l.Lifru[0]))
}
func (l *Lifreq) SetLifruUint(d uint) {
*(*uint)(unsafe.Pointer(&l.Lifru[0])) = d
}
func (l *Lifreq) GetLifruUint() uint {
return *(*uint)(unsafe.Pointer(&l.Lifru[0]))
}
func IoctlLifreq(fd int, req uint, l *Lifreq) error {
return ioctl(fd, req, uintptr(unsafe.Pointer(l)))
}
// Strioctl Helpers
func (s *Strioctl) SetInt(i int) {
s.Len = int32(unsafe.Sizeof(i))
s.Dp = (*int8)(unsafe.Pointer(&i))
}
func IoctlSetStrioctlRetInt(fd int, req uint, s *Strioctl) (int, error) {
return ioctlRet(fd, req, uintptr(unsafe.Pointer(s)))
}

View File

@@ -13,8 +13,6 @@ import (
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/internal/unsafeheader"
)
var (
@@ -117,11 +115,7 @@ func (m *mmapper) Mmap(fd int, offset int64, length int, prot int, flags int) (d
}
// Use unsafe to convert addr into a []byte.
var b []byte
hdr := (*unsafeheader.Slice)(unsafe.Pointer(&b))
hdr.Data = unsafe.Pointer(addr)
hdr.Cap = length
hdr.Len = length
b := unsafe.Slice((*byte)(unsafe.Pointer(addr)), length)
// Register mapping in m and return it.
p := &b[cap(b)-1]
@@ -337,6 +331,19 @@ func Recvfrom(fd int, p []byte, flags int) (n int, from Sockaddr, err error) {
return
}
// Recvmsg receives a message from a socket using the recvmsg system call. The
// received non-control data will be written to p, and any "out of band"
// control data will be written to oob. The flags are passed to recvmsg.
//
// The results are:
// - n is the number of non-control data bytes read into p
// - oobn is the number of control data bytes read into oob; this may be interpreted using [ParseSocketControlMessage]
// - recvflags is flags returned by recvmsg
// - from is the address of the sender
//
// If the underlying socket type is not SOCK_DGRAM, a received message
// containing oob data and a single '\0' of non-control data is treated as if
// the message contained only control data, i.e. n will be zero on return.
func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {
var iov [1]Iovec
if len(p) > 0 {
@@ -352,13 +359,9 @@ func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from
return
}
// RecvmsgBuffers receives a message from a socket using the recvmsg
// system call. The flags are passed to recvmsg. Any non-control data
// read is scattered into the buffers slices. The results are:
// - n is the number of non-control data read into bufs
// - oobn is the number of control data read into oob; this may be interpreted using [ParseSocketControlMessage]
// - recvflags is flags returned by recvmsg
// - from is the address of the sender
// RecvmsgBuffers receives a message from a socket using the recvmsg system
// call. This function is equivalent to Recvmsg, but non-control data read is
// scattered into the buffers slices.
func RecvmsgBuffers(fd int, buffers [][]byte, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {
iov := make([]Iovec, len(buffers))
for i := range buffers {
@@ -377,11 +380,38 @@ func RecvmsgBuffers(fd int, buffers [][]byte, oob []byte, flags int) (n, oobn in
return
}
// Sendmsg sends a message on a socket to an address using the sendmsg system
// call. This function is equivalent to SendmsgN, but does not return the
// number of bytes actually sent.
func Sendmsg(fd int, p, oob []byte, to Sockaddr, flags int) (err error) {
_, err = SendmsgN(fd, p, oob, to, flags)
return
}
// SendmsgN sends a message on a socket to an address using the sendmsg system
// call. p contains the non-control data to send, and oob contains the "out of
// band" control data. The flags are passed to sendmsg. The number of
// non-control bytes actually written to the socket is returned.
//
// Some socket types do not support sending control data without accompanying
// non-control data. If p is empty, and oob contains control data, and the
// underlying socket type is not SOCK_DGRAM, p will be treated as containing a
// single '\0' and the return value will indicate zero bytes sent.
//
// The Go function Recvmsg, if called with an empty p and a non-empty oob,
// will read and ignore this additional '\0'. If the message is received by
// code that does not use Recvmsg, or that does not use Go at all, that code
// will need to be written to expect and ignore the additional '\0'.
//
// If you need to send non-empty oob with p actually empty, and if the
// underlying socket type supports it, you can do so via a raw system call as
// follows:
//
// msg := &unix.Msghdr{
// Control: &oob[0],
// }
// msg.SetControllen(len(oob))
// n, _, errno := unix.Syscall(unix.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), flags)
func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) {
var iov [1]Iovec
if len(p) > 0 {
@@ -400,9 +430,8 @@ func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error)
}
// SendmsgBuffers sends a message on a socket to an address using the sendmsg
// system call. The flags are passed to sendmsg. Any non-control data written
// is gathered from buffers. The function returns the number of bytes written
// to the socket.
// system call. This function is equivalent to SendmsgN, but the non-control
// data is gathered from buffers.
func SendmsgBuffers(fd int, buffers [][]byte, oob []byte, to Sockaddr, flags int) (n int, err error) {
iov := make([]Iovec, len(buffers))
for i := range buffers {
@@ -429,11 +458,15 @@ func Send(s int, buf []byte, flags int) (err error) {
}
func Sendto(fd int, p []byte, flags int, to Sockaddr) (err error) {
ptr, n, err := to.sockaddr()
if err != nil {
return err
var ptr unsafe.Pointer
var salen _Socklen
if to != nil {
ptr, salen, err = to.sockaddr()
if err != nil {
return err
}
}
return sendto(fd, p, flags, ptr, n)
return sendto(fd, p, flags, ptr, salen)
}
func SetsockoptByte(fd, level, opt int, value byte) (err error) {
@@ -545,7 +578,7 @@ func Lutimes(path string, tv []Timeval) error {
return UtimesNanoAt(AT_FDCWD, path, ts, AT_SYMLINK_NOFOLLOW)
}
// emptyIovec reports whether there are no bytes in the slice of Iovec.
// emptyIovecs reports whether there are no bytes in the slice of Iovec.
func emptyIovecs(iov []Iovec) bool {
for i := range iov {
if iov[i].Len > 0 {

View File

@@ -2,11 +2,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris) && gc && !ppc64le && !ppc64
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
//go:build (darwin || dragonfly || freebsd || (linux && !ppc64 && !ppc64le) || netbsd || openbsd || solaris) && gc
// +build darwin dragonfly freebsd linux,!ppc64,!ppc64le netbsd openbsd solaris
// +build gc
// +build !ppc64le
// +build !ppc64
package unix

View File

@@ -9,8 +9,10 @@ package unix
import (
"bytes"
"fmt"
"runtime"
"sort"
"strings"
"sync"
"syscall"
"unsafe"
@@ -55,7 +57,13 @@ func (d *Dirent) NameString() string {
if d == nil {
return ""
}
return string(d.Name[:d.Namlen])
s := string(d.Name[:])
idx := strings.IndexByte(s, 0)
if idx == -1 {
return s
} else {
return s[:idx]
}
}
func (sa *SockaddrInet4) sockaddr() (unsafe.Pointer, _Socklen, error) {
@@ -1230,6 +1238,14 @@ func Readdir(dir uintptr) (*Dirent, error) {
return &ent, err
}
func readdir_r(dirp uintptr, entry *direntLE, result **direntLE) (err error) {
r0, _, e1 := syscall_syscall(SYS___READDIR_R_A, dirp, uintptr(unsafe.Pointer(entry)), uintptr(unsafe.Pointer(result)))
if int64(r0) == -1 {
err = errnoErr(Errno(e1))
}
return
}
func Closedir(dir uintptr) error {
_, _, e := syscall_syscall(SYS_CLOSEDIR, dir, 0, 0)
if e != 0 {
@@ -1821,3 +1837,158 @@ func Unmount(name string, mtm int) (err error) {
}
return err
}
func fdToPath(dirfd int) (path string, err error) {
var buffer [1024]byte
// w_ctrl()
ret := runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS_W_IOCTL<<4,
[]uintptr{uintptr(dirfd), 17, 1024, uintptr(unsafe.Pointer(&buffer[0]))})
if ret == 0 {
zb := bytes.IndexByte(buffer[:], 0)
if zb == -1 {
zb = len(buffer)
}
// __e2a_l()
runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS___E2A_L<<4,
[]uintptr{uintptr(unsafe.Pointer(&buffer[0])), uintptr(zb)})
return string(buffer[:zb]), nil
}
// __errno()
errno := int(*(*int32)(unsafe.Pointer(runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS___ERRNO<<4,
[]uintptr{}))))
// __errno2()
errno2 := int(runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS___ERRNO2<<4,
[]uintptr{}))
// strerror_r()
ret = runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS_STRERROR_R<<4,
[]uintptr{uintptr(errno), uintptr(unsafe.Pointer(&buffer[0])), 1024})
if ret == 0 {
zb := bytes.IndexByte(buffer[:], 0)
if zb == -1 {
zb = len(buffer)
}
return "", fmt.Errorf("%s (errno2=0x%x)", buffer[:zb], errno2)
} else {
return "", fmt.Errorf("fdToPath errno %d (errno2=0x%x)", errno, errno2)
}
}
func direntLeToDirentUnix(dirent *direntLE, dir uintptr, path string) (Dirent, error) {
var d Dirent
d.Ino = uint64(dirent.Ino)
offset, err := Telldir(dir)
if err != nil {
return d, err
}
d.Off = int64(offset)
s := string(bytes.Split(dirent.Name[:], []byte{0})[0])
copy(d.Name[:], s)
d.Reclen = uint16(24 + len(d.NameString()))
var st Stat_t
path = path + "/" + s
err = Lstat(path, &st)
if err != nil {
return d, err
}
d.Type = uint8(st.Mode >> 24)
return d, err
}
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// Simulation of Getdirentries port from the Darwin implementation.
// COMMENTS FROM DARWIN:
// It's not the full required semantics, but should handle the case
// of calling Getdirentries or ReadDirent repeatedly.
// It won't handle assigning the results of lseek to *basep, or handle
// the directory being edited underfoot.
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
if err != nil {
return 0, err
}
// Get path from fd to avoid unavailable call (fdopendir)
path, err := fdToPath(fd)
if err != nil {
return 0, err
}
d, err := Opendir(path)
if err != nil {
return 0, err
}
defer Closedir(d)
var cnt int64
for {
var entryLE direntLE
var entrypLE *direntLE
e := readdir_r(d, &entryLE, &entrypLE)
if e != nil {
return n, e
}
if entrypLE == nil {
break
}
if skip > 0 {
skip--
cnt++
continue
}
// Dirent on zos has a different structure
entry, e := direntLeToDirentUnix(&entryLE, d, path)
if e != nil {
return n, e
}
reclen := int(entry.Reclen)
if reclen > len(buf) {
// Not enough room. Return for now.
// The counter will let us know where we should start up again.
// Note: this strategy for suspending in the middle and
// restarting is O(n^2) in the length of the directory. Oh well.
break
}
// Copy entry into return buffer.
s := unsafe.Slice((*byte)(unsafe.Pointer(&entry)), reclen)
copy(buf, s)
buf = buf[reclen:]
n += reclen
cnt++
}
// Set the seek offset of the input fd to record
// how many files we've already returned.
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
if err != nil {
return n, err
}
return n, nil
}
func ReadDirent(fd int, buf []byte) (n int, err error) {
var base = (*uintptr)(unsafe.Pointer(new(uint64)))
return Getdirentries(fd, buf, base)
}
func direntIno(buf []byte) (uint64, bool) {
return readInt(buf, unsafe.Offsetof(Dirent{}.Ino), unsafe.Sizeof(Dirent{}.Ino))
}
func direntReclen(buf []byte) (uint64, bool) {
return readInt(buf, unsafe.Offsetof(Dirent{}.Reclen), unsafe.Sizeof(Dirent{}.Reclen))
}
func direntNamlen(buf []byte) (uint64, bool) {
reclen, ok := direntReclen(buf)
if !ok {
return 0, false
}
return reclen - uint64(unsafe.Offsetof(Dirent{}.Name)), true
}

View File

@@ -7,11 +7,7 @@
package unix
import (
"unsafe"
"golang.org/x/sys/internal/unsafeheader"
)
import "unsafe"
// SysvShmAttach attaches the Sysv shared memory segment associated with the
// shared memory identifier id.
@@ -34,12 +30,7 @@ func SysvShmAttach(id int, addr uintptr, flag int) ([]byte, error) {
}
// Use unsafe to convert addr into a []byte.
// TODO: convert to unsafe.Slice once we can assume Go 1.17
var b []byte
hdr := (*unsafeheader.Slice)(unsafe.Pointer(&b))
hdr.Data = unsafe.Pointer(addr)
hdr.Cap = int(info.Segsz)
hdr.Len = int(info.Segsz)
b := unsafe.Slice((*byte)(unsafe.Pointer(addr)), int(info.Segsz))
return b, nil
}

Some files were not shown because too many files have changed in this diff Show More