mirror of
				https://github.com/mochi-mqtt/server.git
				synced 2025-10-31 19:42:38 +08:00 
			
		
		
		
	Compare commits
	
		
			102 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 0648e39507 | ||
|   | 233a82e448 | ||
|   | 51a8d8cb54 | ||
|   | 23c3208310 | ||
|   | 23e1092cda | ||
|   | d498576927 | ||
|   | 7e14ce99b5 | ||
|   | 4db49a4b9d | ||
|   | e60b8ff0c9 | ||
|   | b9d5dcb5f0 | ||
|   | 6d394d1fe9 | ||
|   | 1ee2158711 | ||
|   | af79b55b9f | ||
|   | e1a9497c25 | ||
|   | 62659e17ba | ||
|   | 7ad6dd8e1a | ||
|   | 565e07747e | ||
|   | 6acd775a6b | ||
|   | 493f6c8bb0 | ||
|   | d3785c2717 | ||
|   | 52a347169a | ||
|   | 797d75cb34 | ||
|   | 5225a357e5 | ||
|   | a734a0dc73 | ||
|   | 6704cf7227 | ||
|   | 9233e6fd39 | ||
|   | 1ca65d9631 | ||
|   | 33229da885 | ||
|   | c274d5fd08 | ||
|   | 10e82f41d6 | ||
|   | e6c07b2b78 | ||
|   | eed3ef9606 | ||
|   | 1ec880844d | ||
|   | 4b49652a8c | ||
|   | d46e7b5bcf | ||
|   | 17fb7dadbc | ||
|   | ed7fd836e1 | ||
|   | 605bb93c75 | ||
|   | c73ace2ea0 | ||
|   | aac6d699da | ||
|   | 7bd7bd5087 | ||
|   | 655bf9fdb1 | ||
|   | b188055c7d | ||
|   | aaf1d9d4c6 | ||
|   | 44ce819318 | ||
| ![dependabot[bot]](/assets/img/avatar_default.png)  | e4c76cc60c | ||
|   | da79faa972 | ||
|   | 46babc89c8 | ||
|   | 9b7a943888 | ||
|   | a909d30923 | ||
|   | 0851b09e4d | ||
|   | a302c9dd88 | ||
|   | 1e8f922102 | ||
|   | 4c16e5593f | ||
|   | 49cada4fbc | ||
|   | ef34510c0b | ||
|   | e5716caad1 | ||
|   | 4b039cb35c | ||
|   | aac245441a | ||
|   | bb54cc68e6 | ||
|   | 7ba1352a60 | ||
|   | ca849131eb | ||
|   | ba7e534122 | ||
|   | db760c34a5 | ||
|   | ae3ee81bb4 | ||
|   | c2ca02d149 | ||
|   | 77a64d9c87 | ||
|   | 8dec9cc962 | ||
|   | f90e52328d | ||
|   | 50aae47618 | ||
|   | 0d79f2d63b | ||
|   | 300152413c | ||
|   | 0de1d731db | ||
|   | 80746abc52 | ||
|   | a73cf4ca0e | ||
|   | bc549ee7ed | ||
|   | c464b46713 | ||
|   | 05ce56008c | ||
|   | 8254cb0cbc | ||
|   | 4ae58b79e3 | ||
|   | b895d688e0 | ||
|   | a600cd4ead | ||
|   | cdb44990cf | ||
|   | 2d9c128111 | ||
|   | a0d5bdb39f | ||
|   | 4ebcef3cb6 | ||
|   | fb8d4720d7 | ||
|   | 4080c89127 | ||
|   | 1b67e6f3f6 | ||
|   | 1adb02e087 | ||
|   | 4d4140aa99 | ||
|   | e31840a37d | ||
|   | 7d2e16f2d3 | ||
|   | 92cd935a16 | ||
|   | 25ce27ce2d | ||
|   | 527d084a4b | ||
|   | bb9f937bb0 | ||
|   | 511fe88684 | ||
|   | 75504ff201 | ||
|   | a556feb325 | ||
|   | d06f47f4b9 | ||
|   | 8d4cc091b4 | 
							
								
								
									
										42
									
								
								.github/workflows/build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										42
									
								
								.github/workflows/build.yml
									
									
									
									
										vendored
									
									
								
							| @@ -7,37 +7,39 @@ jobs: | ||||
|   build: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|     - uses: actions/checkout@v2 | ||||
|     - uses: actions/checkout@v3 | ||||
|  | ||||
|     - name: Set up Go | ||||
|       uses: actions/setup-go@v2 | ||||
|       uses: actions/setup-go@v3 | ||||
|       with: | ||||
|         go-version: 1.19 | ||||
|  | ||||
|     - name: Vet | ||||
|       run: go vet ./... | ||||
|      | ||||
|  | ||||
|     - name: Test | ||||
|       run: go test -race ./... && echo true | ||||
|        | ||||
|  | ||||
|   coverage: | ||||
|     name: Test with Coverage | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - name: Install Go | ||||
|         if: success() | ||||
|         uses: actions/setup-go@v2 | ||||
|         with: | ||||
|           go-version: 1.19.x | ||||
|       - name: Checkout code | ||||
|         uses: actions/checkout@v2 | ||||
|       - name: Calc coverage | ||||
|         run: | | ||||
|           go test -v -covermode=count -coverprofile=coverage.out ./... | ||||
|       - name: Convert coverage.out to coverage.lcov | ||||
|         uses: jandelgado/gcov2lcov-action@v1.0.6 | ||||
|       - name: Coveralls | ||||
|         uses: coverallsapp/github-action@v1.1.2 | ||||
|         with: | ||||
|           github-token: ${{ secrets.github_token }} | ||||
|           path-to-lcov: coverage.lcov | ||||
|     - name: Set up Go | ||||
|       uses: actions/setup-go@v3 | ||||
|       with: | ||||
|         go-version: '1.19' | ||||
|     - name: Check out code | ||||
|       uses: actions/checkout@v3 | ||||
|     - name: Install dependencies | ||||
|       run: | | ||||
|         go mod download | ||||
|     - name: Run Unit tests | ||||
|       run: | | ||||
|         go test -race -covermode atomic -coverprofile=covprofile ./... | ||||
|     - name: Install goveralls | ||||
|       run: go install github.com/mattn/goveralls@latest | ||||
|     - name: Send coverage | ||||
|       env: | ||||
|         COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||||
|       run: goveralls -coverprofile=covprofile -service=github | ||||
|   | ||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,3 +1,4 @@ | ||||
| cmd/mqtt | ||||
| .DS_Store | ||||
| *.db | ||||
| .idea | ||||
							
								
								
									
										165
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										165
									
								
								README.md
									
									
									
									
									
								
							| @@ -2,7 +2,7 @@ | ||||
| <p align="center"> | ||||
|  | ||||
|   | ||||
| [](https://coveralls.io/github/mochi-co/mqtt?branch=master) | ||||
| [](https://coveralls.io/github/mochi-co/mqtt?branch=master) | ||||
| [](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2) | ||||
| [](https://pkg.go.dev/github.com/mochi-co/mqtt/v2) | ||||
| [](https://github.com/mochi-co/mqtt/issues) | ||||
| @@ -16,7 +16,7 @@ Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/m | ||||
| ### What is MQTT? | ||||
| MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol. | ||||
|  | ||||
| ## What's new in Version 2.0.0? | ||||
| ## What's new in Version 2? | ||||
| Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.  | ||||
|  | ||||
| Don't forget to use the new v2 import paths: | ||||
| @@ -37,14 +37,14 @@ import "github.com/mochi-co/mqtt/v2" | ||||
|     - Plus all the original MQTT features of Mochi MQTT v1, such as Full QoS(0,1,2), $SYS topics, retained messages, etc.  | ||||
| - Developer-centric: | ||||
|     - Most core broker code is now exported and accessible, for total developer control. | ||||
|     - Full featured and flexible Hook-based interfacing system to provide easy 'plugin' development. | ||||
|     - Full-featured and flexible Hook-based interfacing system to provide easy 'plugin' development. | ||||
|     - Direct Packet Injection using special inline client, or masquerade as existing clients. | ||||
| - Performant and Stable: | ||||
|     - Our classic trie-based Topic-Subscription model. | ||||
|     - A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.  | ||||
|     - Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour. | ||||
|     - Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3. | ||||
|     - Over a thousand carefully considered unit test scenarios. | ||||
| - TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners. | ||||
| - TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners. | ||||
| - Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own). | ||||
| - Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own). | ||||
|  | ||||
| @@ -83,22 +83,26 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest | ||||
| Importing Mochi MQTT as a package requires just a few lines of code to get started. | ||||
| ``` go | ||||
| import ( | ||||
|   "log" | ||||
|  | ||||
|   "github.com/mochi-co/mqtt/v2" | ||||
|   "github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
|   "github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
|   // Create the new MQTT Server. | ||||
|   server := mqtt.New(nil) | ||||
|  | ||||
|    | ||||
|   // Allow all connections. | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
|   _ = server.AddHook(new(auth.AllowHook), nil) | ||||
|    | ||||
|   // Create a TCP listener on a standard port. | ||||
| 	tcp := listeners.NewTCP("t1", *tcpAddr, nil) | ||||
| 	err := server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|   tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
|   err := server.AddListener(tcp) | ||||
|   if err != nil { | ||||
|     log.Fatal(err) | ||||
|   } | ||||
|    | ||||
|   err = server.Serve() | ||||
|   if err != nil { | ||||
| @@ -112,10 +116,16 @@ Examples of running the broker with various configurations can be found in the [ | ||||
| #### Network Listeners | ||||
| The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are: | ||||
|  | ||||
| - `listeners.NewTCP(...)` - A TCP listener. | ||||
| - `listeners.NewWebsocket(...)` A Websocket listener. | ||||
| - `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard. | ||||
| - Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know! | ||||
| | Listener                     | Usage                                                                                        | | ||||
| |------------------------------|----------------------------------------------------------------------------------------------| | ||||
| | listeners.NewTCP             | A TCP listener                                                                               | | ||||
| | listeners.NewUnixSock        | A Unix Socket listener                                                                       | | ||||
| | listeners.NewNet             | A net.Listener listener                                                                      | | ||||
| | listeners.NewWebsocket       | A Websocket listener                                                                         | | ||||
| | listeners.NewHTTPStats       | An HTTP $SYS info dashboard                                                                  | | ||||
| | listeners.NewHTTPHealthCheck | An HTTP healthcheck listener to provide health check responses for e.g. cloud infrastructure | | ||||
|  | ||||
| > Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know! | ||||
|  | ||||
| A `*listeners.Config` may be passed to configure TLS.  | ||||
|  | ||||
| @@ -132,11 +142,13 @@ server := mqtt.New(&mqtt.Options{ | ||||
|       ObscureNotAuthorized: true, | ||||
|     }, | ||||
|   }, | ||||
|   ClientNetWriteBufferSize: 4096, | ||||
|   ClientNetReadBufferSize: 4096, | ||||
|   SysTopicResendInterval: 10, | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. | ||||
| Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs. | ||||
|  | ||||
|  | ||||
| ## Event Hooks  | ||||
| @@ -258,50 +270,50 @@ For more information on how the badger hook works, or how to use it, see the [ex | ||||
|  | ||||
| There is also a BoltDB hook which has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go). | ||||
|  | ||||
|  | ||||
|  | ||||
| ## Developing with Event Hooks | ||||
| Many hooks are available for interacting with the broker and client lifecycle.  | ||||
| The function signatures for all the hooks and `mqtt.Hook` interface can be found in [hooks.go](hooks.go). | ||||
|  | ||||
| > The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks be used to control and modify all incoming and outgoing packets. | ||||
|  | ||||
| | Function | Usage |  | ||||
| | -------------------------- | -- | | ||||
| | OnStarted | Called when the server has successfully started.| | ||||
| | OnStopped | Called when the server has successfully stopped. |  | ||||
| | OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. | | ||||
| | OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. | | ||||
| | OnSysInfoTick | Called when the $SYS topic values are published out. | | ||||
| | OnConnect | Called when a new client connects |  | ||||
| | OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |  | ||||
| | OnDisconnect | Called when a client is disconnected for any reason. |  | ||||
| | OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |  | ||||
| | OnPacketRead | Called when a packet is received from a client. Allows packet modification. |  | ||||
| | OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |  | ||||
| | OnPacketSent | Called when a packet has been sent to a client. |  | ||||
| | OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |  | ||||
| | OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |  | ||||
| | OnSubscribed | Called when a client successfully subscribes to one or more filters. |  | ||||
| | OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.|  | ||||
| | OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |  | ||||
| | OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |  | ||||
| | OnPublish | Called when a client publishes a message. Allows packet modification. |  | ||||
| | OnPublished | Called when a client has published a message to subscribers. |  | ||||
| | OnRetainMessage | Called then a published message is retained. |  | ||||
| | OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |  | ||||
| | OnQosComplete | Called when the Qos flow for a message has been completed. |  | ||||
| | OnQosDropped | Called when an inflight message expires before completion. |  | ||||
| | OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |  | ||||
| | OnWillSent | Called when an LWT message has been issued from a disconnecting client. |  | ||||
| | OnClientExpired | Called when a client session has expired and should be deleted. |  | ||||
| | OnRetainedExpired | Called when a retained message has expired and should be deleted. |  | ||||
| | OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|  | ||||
| | StoredClients |  Returns clients, eg. from a persistent store. |  | ||||
| | StoredSubscriptions |  Returns client subscriptions, eg. from a persistent store. |  | ||||
| | StoredInflightMessages | Returns inflight messages, eg. from a persistent store.  |  | ||||
| | StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |  | ||||
| | StoredSysInfo | Returns stored system info values, eg. from a persistent store. |  | ||||
| | Function               | Usage                                                                                                                                                                                                                                                                                                      |  | ||||
| |------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | ||||
| | OnStarted              | Called when the server has successfully started.                                                                                                                                                                                                                                                           | | ||||
| | OnStopped              | Called when the server has successfully stopped.                                                                                                                                                                                                                                                           |  | ||||
| | OnConnectAuthenticate  | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. | | ||||
| | OnACLCheck             | Called when a user attempts to publish or subscribe to a topic filter. As above.                                                                                                                                                                                                                           | | ||||
| | OnSysInfoTick          | Called when the $SYS topic values are published out.                                                                                                                                                                                                                                                       | | ||||
| | OnConnect              | Called when a new client connects, may return an error or packet code to halt the client connection process.                                                                                                                                                                                               |  | ||||
| | OnSessionEstablished   | Called when a new client successfully establishes a session (after OnConnect)                                                                                                                                                                                                                              |  | ||||
| | OnDisconnect           | Called when a client is disconnected for any reason.                                                                                                                                                                                                                                                       |  | ||||
| | OnAuthPacket           | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification.                                                                                                                                        |  | ||||
| | OnPacketRead           | Called when a packet is received from a client. Allows packet modification.                                                                                                                                                                                                                                |  | ||||
| | OnPacketEncode         | Called immediately before a packet is encoded to be sent to a client. Allows packet modification.                                                                                                                                                                                                          |  | ||||
| | OnPacketSent           | Called when a packet has been sent to a client.                                                                                                                                                                                                                                                            |  | ||||
| | OnPacketProcessed      | Called when a packet has been received and successfully handled by the broker.                                                                                                                                                                                                                             |  | ||||
| | OnSubscribe            | Called when a client subscribes to one or more filters. Allows packet modification.                                                                                                                                                                                                                        |  | ||||
| | OnSubscribed           | Called when a client successfully subscribes to one or more filters.                                                                                                                                                                                                                                       |  | ||||
| | OnSelectSubscribers    | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.                                                                                                                                                    |  | ||||
| | OnUnsubscribe          | Called when a client unsubscribes from one or more filters. Allows packet modification.                                                                                                                                                                                                                    |  | ||||
| | OnUnsubscribed         | Called when a client successfully unsubscribes from one or more filters.                                                                                                                                                                                                                                   |  | ||||
| | OnPublish              | Called when a client publishes a message. Allows packet modification.                                                                                                                                                                                                                                      |  | ||||
| | OnPublished            | Called when a client has published a message to subscribers.                                                                                                                                                                                                                                               |  | ||||
| | OnPublishDropped       | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond.                                                                                                                                                                                         |  | ||||
| | OnRetainMessage        | Called then a published message is retained.                                                                                                                                                                                                                                                               |  | ||||
| | OnRetainPublished      | Called then a retained message is published to a client.                                                                                                                                                                                                                                                   |  | ||||
| | OnQosPublish           | Called when a publish packet with Qos >= 1 is issued to a subscriber.                                                                                                                                                                                                                                      |  | ||||
| | OnQosComplete          | Called when the Qos flow for a message has been completed.                                                                                                                                                                                                                                                 |  | ||||
| | OnQosDropped           | Called when an inflight message expires before completion.                                                                                                                                                                                                                                                 |  | ||||
| | OnPacketIDExhausted    | Called when a client runs out of unused packet ids to assign.                                                                                                                                                                                                                                              |  | ||||
| | OnWill                 | Called when a client disconnects and intends to issue a will message. Allows packet modification.                                                                                                                                                                                                          |  | ||||
| | OnWillSent             | Called when an LWT message has been issued from a disconnecting client.                                                                                                                                                                                                                                    |  | ||||
| | OnClientExpired        | Called when a client session has expired and should be deleted.                                                                                                                                                                                                                                            |  | ||||
| | OnRetainedExpired      | Called when a retained message has expired and should be deleted.                                                                                                                                                                                                                                          |  | ||||
| | StoredClients          | Returns clients, eg. from a persistent store.                                                                                                                                                                                                                                                              |  | ||||
| | StoredSubscriptions    | Returns client subscriptions, eg. from a persistent store.                                                                                                                                                                                                                                                 |  | ||||
| | StoredInflightMessages | Returns inflight messages, eg. from a persistent store.                                                                                                                                                                                                                                                    |  | ||||
| | StoredRetainedMessages | Returns retained messages, eg. from a persistent store.                                                                                                                                                                                                                                                    |  | ||||
| | StoredSysInfo          | Returns stored system info values, eg. from a persistent store.                                                                                                                                                                                                                                            |  | ||||
|  | ||||
| If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`. | ||||
|  | ||||
| @@ -334,6 +346,8 @@ server.InjectPacket(cl, packets.Packet{ | ||||
|  | ||||
| See the [hooks example](examples/hooks/main.go) to see this feature in action. | ||||
|  | ||||
|  | ||||
|  | ||||
| ### Testing | ||||
| #### Unit Tests | ||||
| Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go: | ||||
| @@ -353,39 +367,54 @@ Mochi MQTT performance is comparable with popular brokers such as Mosquitto, EMQ | ||||
| Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a Apple Macbook Air M2, using `cmd/main.go` default settings. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better. | ||||
|  | ||||
| > The values presented in the benchmark are not representative of true messages per second throughput. They rely on an unusual calculation by mqtt-stresser, but are usable as they are consistent across all brokers. | ||||
| > Benchmarks are provided as a general performance expectation guideline only. | ||||
| > Benchmarks are provided as a general performance expectation guideline only. Comparisons are performed using out-of-the-box default configurations. | ||||
|  | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000` | ||||
| | Broker            | publish fastest | median | slowest | receive fastest | median | slowest |  | ||||
| | --                | --             | --   | --   | --             | --   | --   | | ||||
| | Mochi v2.0.0      | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 | | ||||
| | Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 | | ||||
| | EMQX v5.0.11      | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 | | ||||
| | Mochi v2.2.10      | 124,772 | 125,456 | 124,614 | 314,461 | 313,186 | 311,910 | | ||||
| | [Mosquitto v2.0.15](https://github.com/eclipse/mosquitto) | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 | | ||||
| | [EMQX v5.0.11](https://github.com/emqx/emqx)      | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 | | ||||
| | [Rumqtt v0.21.0](https://github.com/bytebeamio/rumqtt) | 112,208 | 108,480 | 104,753 | 135,784 | 126,446 | 117,108 | | ||||
|  | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000` | ||||
| | Broker            | publish fastest | median | slowest | receive fastest | median | slowest |  | ||||
| | --                | --             | --   | --   | --             | --   | --   | | ||||
| | Mochi v2.0.0      | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 | | ||||
| | Mochi v2.2.10      | 41,825 | 31,663| 23,008 | 144,058 | 65,903 | 37,618 | | ||||
| | Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 | | ||||
| | EMQX v5.0.11      | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 | | ||||
| | EMQX v5.0.11      | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 | | ||||
| | Rumqtt v0.21.0    | 42,213 | 23,153 | 20,814 | 49,465 | 36,626 | 19,283 | | ||||
|  | ||||
| Million Message Challenge (hit the server with 1 million messages immediately): | ||||
|  | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000` | ||||
| | Broker            | publish fastest | median | slowest | receive fastest | median | slowest |  | ||||
| | --                | --             | --   | --   | --             | --   | --   | | ||||
| | Mochi v2.0.0      | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636  | | ||||
| | Mochi v2.2.10     | 13,532 | 4,425 | 2,344 | 52,120 | 7,274 | 2,701 | | ||||
| | Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 | | ||||
| | EMQX v5.0.11      | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 | | ||||
| | Rumqtt v0.21.0    | 78,972 | 5,047 | 3,804 | 4,286 | 3,249 | 2,027 | | ||||
|  | ||||
| > Not sure what's going on with EMQX here, perhaps the docker out-of-the-box settings are not optimal, so take it with a pinch of salt as we know for a fact it's a solid piece of software. | ||||
|  | ||||
| ## Contribution Guidelines | ||||
| Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request. If you open a pull request, please try to follow the following guidelines: | ||||
| - Try to maintain test coverage where reasonably possible. | ||||
| - Clearly state what the PR does and why. | ||||
| - Remember to add your SPDX FileContributor tag to files where you have made a meaningful contribution. | ||||
|  | ||||
| [SPDX Annotations](https://spdx.dev) are used to clearly indicate the license, copyright, and contributions of each file in a machine-readable format. If you are adding a new file to the repository, please ensure it has the following SPDX header: | ||||
| ```go | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: Your name or alias <optional@email.address> | ||||
|  | ||||
| package name | ||||
| ``` | ||||
|  | ||||
| Please ensure to add a new `SPDX-FileContributor` line for each contributor to the file. Refer to other files for examples. Please remember to do this, your contributions to this project are valuable and appreciated - it's important to receive credit!  | ||||
|  | ||||
| ## Stargazers over time 🥰 | ||||
| [](https://starchart.cc/mochi-co/mqtt) | ||||
| Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-co/mqtt/issues) | ||||
|  | ||||
| ## Contributions | ||||
| Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request. | ||||
|  | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										160
									
								
								clients.go
									
									
									
									
									
								
							
							
						
						
									
										160
									
								
								clients.go
									
									
									
									
									
								
							| @@ -7,6 +7,7 @@ package mqtt | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| @@ -87,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client { | ||||
| 	defer cl.RUnlock() | ||||
| 	clients := make([]*Client, 0, cl.Len()) | ||||
| 	for _, client := range cl.internal { | ||||
| 		if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 { | ||||
| 		if client.Net.Listener == id && !client.Closed() { | ||||
| 			clients = append(clients, client) | ||||
| 		} | ||||
| 	} | ||||
| @@ -106,7 +107,7 @@ type Client struct { | ||||
|  | ||||
| // ClientConnection contains the connection transport and metadata for the client. | ||||
| type ClientConnection struct { | ||||
| 	conn     net.Conn          // the net.Conn used to establish the connection | ||||
| 	Conn     net.Conn          // the net.Conn used to establish the connection | ||||
| 	bconn    *bufio.ReadWriter // a buffered net.Conn for reading packets | ||||
| 	Remote   string            // the remote address of the client | ||||
| 	Listener string            // listener id of the client | ||||
| @@ -135,26 +136,35 @@ type Will struct { | ||||
|  | ||||
| // State tracks the state of the client. | ||||
| type ClientState struct { | ||||
| 	TopicAliases  TopicAliases   // a map of topic aliases | ||||
| 	stopCause     atomic.Value   // reason for stopping | ||||
| 	Inflight      *Inflight      // a map of in-flight qos messages | ||||
| 	Subscriptions *Subscriptions // a map of the subscription filters a client maintains | ||||
| 	disconnected  int64          // the time the client disconnected in unix time, for calculating expiry | ||||
| 	endOnce       sync.Once      // only end once | ||||
| 	packetID      uint32         // the current highest packetID | ||||
| 	done          uint32         // atomic counter which indicates that the client has closed | ||||
| 	keepalive     uint16         // the number of seconds the connection can wait | ||||
| 	TopicAliases    TopicAliases         // a map of topic aliases | ||||
| 	stopCause       atomic.Value         // reason for stopping | ||||
| 	Inflight        *Inflight            // a map of in-flight qos messages | ||||
| 	Subscriptions   *Subscriptions       // a map of the subscription filters a client maintains | ||||
| 	disconnected    int64                // the time the client disconnected in unix time, for calculating expiry | ||||
| 	outbound        chan *packets.Packet // queue for pending outbound packets | ||||
| 	endOnce         sync.Once            // only end once | ||||
| 	isTakenOver     uint32               // used to identify orphaned clients | ||||
| 	packetID        uint32               // the current highest packetID | ||||
| 	open            context.Context      // indicate that the client is open for packet exchange | ||||
| 	cancelOpen      context.CancelFunc   // cancel function for open context | ||||
| 	outboundQty     int32                // number of messages currently in the outbound queue | ||||
| 	Keepalive       uint16               // the number of seconds the connection can wait | ||||
| 	ServerKeepalive bool                 // keepalive was set by the server | ||||
| } | ||||
|  | ||||
| // newClient returns a new instance of Client. This is almost exclusively used by Server | ||||
| // for creating new clients, but it lives here because it's not dependent. | ||||
| func newClient(c net.Conn, o *ops) *Client { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	cl := &Client{ | ||||
| 		State: ClientState{ | ||||
| 			Inflight:      NewInflights(), | ||||
| 			Subscriptions: NewSubscriptions(), | ||||
| 			TopicAliases:  NewTopicAliases(o.capabilities.TopicAliasMaximum), | ||||
| 			keepalive:     defaultKeepalive, | ||||
| 			TopicAliases:  NewTopicAliases(o.options.Capabilities.TopicAliasMaximum), | ||||
| 			open:          ctx, | ||||
| 			cancelOpen:    cancel, | ||||
| 			Keepalive:     defaultKeepalive, | ||||
| 			outbound:      make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending), | ||||
| 		}, | ||||
| 		Properties: ClientProperties{ | ||||
| 			ProtocolVersion: defaultClientProtocolVersion, // default protocol version | ||||
| @@ -164,17 +174,33 @@ func newClient(c net.Conn, o *ops) *Client { | ||||
|  | ||||
| 	if c != nil { | ||||
| 		cl.Net = ClientConnection{ | ||||
| 			conn:   c, | ||||
| 			bconn:  bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)), | ||||
| 			Conn: c, | ||||
| 			bconn: bufio.NewReadWriter( | ||||
| 				bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize), | ||||
| 				bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize), | ||||
| 			), | ||||
| 			Remote: c.RemoteAddr().String(), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	cl.refreshDeadline(cl.State.keepalive) | ||||
|  | ||||
| 	return cl | ||||
| } | ||||
|  | ||||
| // WriteLoop ranges over pending outbound messages and writes them to the client connection. | ||||
| func (cl *Client) WriteLoop() { | ||||
| 	for { | ||||
| 		select { | ||||
| 		case pk := <-cl.State.outbound: | ||||
| 			if err := cl.WritePacket(*pk); err != nil { | ||||
| 				cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet") | ||||
| 			} | ||||
| 			atomic.AddInt32(&cl.State.outboundQty, -1) | ||||
| 		case <-cl.State.open.Done(): | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ParseConnect parses the connect parameters and properties for a client. | ||||
| func (cl *Client) ParseConnect(lid string, pk packets.Packet) { | ||||
| 	cl.Net.Listener = lid | ||||
| @@ -184,9 +210,9 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) { | ||||
| 	cl.Properties.Clean = pk.Connect.Clean | ||||
| 	cl.Properties.Props = pk.Properties.Copy(false) | ||||
|  | ||||
| 	cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client | ||||
| 	cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))    // client receive max | ||||
|  | ||||
| 	cl.State.Keepalive = pk.Connect.Keepalive                                              // [MQTT-3.2.2-22] | ||||
| 	cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client | ||||
| 	cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))            // client receive max | ||||
| 	cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum) | ||||
|  | ||||
| 	cl.ID = pk.Connect.ClientIdentifier | ||||
| @@ -195,11 +221,6 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) { | ||||
| 		cl.Properties.Props.AssignedClientID = cl.ID | ||||
| 	} | ||||
|  | ||||
| 	cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive | ||||
| 	if pk.Connect.Keepalive > 0 { | ||||
| 		cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22] | ||||
| 	} | ||||
|  | ||||
| 	if pk.Connect.WillFlag { | ||||
| 		cl.Properties.Will = Will{ | ||||
| 			Qos:               pk.Connect.WillQos, | ||||
| @@ -217,19 +238,17 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) { | ||||
| 			cl.Properties.Will.Flag = 1 // atomic for checking | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	cl.refreshDeadline(cl.State.keepalive) | ||||
| } | ||||
|  | ||||
| // refreshDeadline refreshes the read/write deadline for the net.Conn connection. | ||||
| func (cl *Client) refreshDeadline(keepalive uint16) { | ||||
| 	if cl.Net.conn != nil { | ||||
| 		var expiry time.Time // nil time can be used to disable deadline if keepalive = 0 | ||||
| 		if keepalive > 0 { | ||||
| 			expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22] | ||||
| 		} | ||||
| 	var expiry time.Time // nil time can be used to disable deadline if keepalive = 0 | ||||
| 	if keepalive > 0 { | ||||
| 		expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22] | ||||
| 	} | ||||
|  | ||||
| 		_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22] | ||||
| 	if cl.Net.Conn != nil { | ||||
| 		_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22] | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -237,28 +256,30 @@ func (cl *Client) refreshDeadline(keepalive uint16) { | ||||
| // If no unused packet ids are available, an error is returned and the client | ||||
| // should be disconnected. | ||||
| func (cl *Client) NextPacketID() (i uint32, err error) { | ||||
| 	cl.Lock() | ||||
| 	defer cl.Unlock() | ||||
|  | ||||
| 	i = atomic.LoadUint32(&cl.State.packetID) | ||||
| 	started := i + 1 | ||||
| 	started := i | ||||
| 	overflowed := false | ||||
| 	for { | ||||
| 		if i >= 65535 { | ||||
| 			overflowed = true | ||||
| 			i = 1 | ||||
| 		} else { | ||||
| 			i++ | ||||
| 		} | ||||
|  | ||||
| 		if overflowed && i == started { | ||||
| 			return 0, packets.ErrQuotaExceeded | ||||
| 		} | ||||
|  | ||||
| 		if i >= cl.ops.options.Capabilities.maximumPacketID { | ||||
| 			overflowed = true | ||||
| 			i = 0 | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		i++ | ||||
|  | ||||
| 		if _, ok := cl.State.Inflight.Get(uint16(i)); !ok { | ||||
| 			break | ||||
| 			atomic.StoreUint32(&cl.State.packetID, i) | ||||
| 			return i, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	atomic.StoreUint32(&cl.State.packetID, i) | ||||
| 	return i, nil | ||||
| } | ||||
|  | ||||
| // ResendInflightMessages attempts to resend any pending inflight messages to connected clients. | ||||
| @@ -272,7 +293,7 @@ func (cl *Client) ResendInflightMessages(force bool) error { | ||||
| 			tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3] | ||||
| 		} | ||||
|  | ||||
| 		//	cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends) | ||||
| 		cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0) | ||||
| 		err := cl.WritePacket(tk) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| @@ -290,17 +311,18 @@ func (cl *Client) ResendInflightMessages(force bool) error { | ||||
| } | ||||
|  | ||||
| // ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session. | ||||
| func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 { | ||||
| 	var deleted int64 | ||||
| func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 { | ||||
| 	deleted := []uint16{} | ||||
| 	for _, tk := range cl.State.Inflight.GetAll(false) { | ||||
| 		if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now { | ||||
| 			if ok := cl.State.Inflight.Delete(tk.PacketID); ok { | ||||
| 				cl.ops.hooks.OnQosDropped(cl, tk) | ||||
| 				atomic.AddInt64(&cl.ops.info.Inflight, -1) | ||||
| 				deleted++ | ||||
| 				deleted = append(deleted, tk.PacketID) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return deleted | ||||
| } | ||||
|  | ||||
| @@ -310,11 +332,11 @@ func (cl *Client) Read(packetHandler ReadFn) error { | ||||
| 	var err error | ||||
|  | ||||
| 	for { | ||||
| 		if atomic.LoadUint32(&cl.State.done) == 1 { | ||||
| 		if cl.Closed() { | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		cl.refreshDeadline(cl.State.keepalive) | ||||
| 		cl.refreshDeadline(cl.State.Keepalive) | ||||
| 		fh := new(packets.FixedHeader) | ||||
| 		err = cl.ReadFixedHeader(fh) | ||||
| 		if err != nil { | ||||
| @@ -335,20 +357,20 @@ func (cl *Client) Read(packetHandler ReadFn) error { | ||||
|  | ||||
| // Stop instructs the client to shut down all processing goroutines and disconnect. | ||||
| func (cl *Client) Stop(err error) { | ||||
| 	if atomic.LoadUint32(&cl.State.done) == 1 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	cl.State.endOnce.Do(func() { | ||||
| 		if cl.Net.conn != nil { | ||||
| 			_ = cl.Net.conn.Close() // omit close error | ||||
|  | ||||
| 		if cl.Net.Conn != nil { | ||||
| 			_ = cl.Net.Conn.Close() // omit close error | ||||
| 		} | ||||
|  | ||||
| 		if err != nil { | ||||
| 			cl.State.stopCause.Store(err) | ||||
| 		} | ||||
|  | ||||
| 		atomic.StoreUint32(&cl.State.done, 1) | ||||
| 		if cl.State.cancelOpen != nil { | ||||
| 			cl.State.cancelOpen() | ||||
| 		} | ||||
|  | ||||
| 		atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix()) | ||||
| 	}) | ||||
| } | ||||
| @@ -361,6 +383,11 @@ func (cl *Client) StopCause() error { | ||||
| 	return cl.State.stopCause.Load().(error) | ||||
| } | ||||
|  | ||||
| // Closed returns true if client connection is closed. | ||||
| func (cl *Client) Closed() bool { | ||||
| 	return cl.State.open == nil || cl.State.open.Err() != nil | ||||
| } | ||||
|  | ||||
| // ReadFixedHeader reads in the values of the next packet's fixed header. | ||||
| func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { | ||||
| 	if cl.Net.bconn == nil { | ||||
| @@ -383,7 +410,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize { | ||||
| 	if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize { | ||||
| 		return packets.ErrPacketTooLarge // [MQTT-3.2.2-15] | ||||
| 	} | ||||
|  | ||||
| @@ -454,15 +481,14 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er | ||||
|  | ||||
| // WritePacket encodes and writes a packet to the client. | ||||
| func (cl *Client) WritePacket(pk packets.Packet) error { | ||||
| 	if atomic.LoadUint32(&cl.State.done) == 1 { | ||||
| 	if cl.Closed() { | ||||
| 		return ErrConnectionClosed | ||||
| 	} | ||||
|  | ||||
| 	if cl.Net.conn == nil { | ||||
| 	if cl.Net.Conn == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	defer cl.refreshDeadline(cl.State.keepalive) | ||||
| 	if pk.Expiry > 0 { | ||||
| 		pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6] | ||||
| 	} | ||||
| @@ -476,8 +502,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error { | ||||
| 		pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set | ||||
| 	} | ||||
|  | ||||
| 	if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo { | ||||
| 		pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode | ||||
| 	if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo { | ||||
| 		pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode | ||||
| 	} | ||||
|  | ||||
| 	pk = cl.ops.hooks.OnPacketEncode(cl, pk) | ||||
| @@ -527,7 +553,11 @@ func (cl *Client) WritePacket(pk packets.Packet) error { | ||||
| 	} | ||||
|  | ||||
| 	nb := net.Buffers{buf.Bytes()} | ||||
| 	n, err := nb.WriteTo(cl.Net.conn) | ||||
| 	n, err := func() (int64, error) { | ||||
| 		cl.Lock() | ||||
| 		defer cl.Unlock() | ||||
| 		return nb.WriteTo(cl.Net.Conn) | ||||
| 	}() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -5,6 +5,7 @@ | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net" | ||||
| @@ -29,9 +30,13 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) { | ||||
| 		info:  new(system.Info), | ||||
| 		hooks: new(Hooks), | ||||
| 		log:   &logger, | ||||
| 		capabilities: &Capabilities{ | ||||
| 			ReceiveMaximum:    10, | ||||
| 			TopicAliasMaximum: 10000, | ||||
| 		options: &Options{ | ||||
| 			Capabilities: &Capabilities{ | ||||
| 				ReceiveMaximum:             10, | ||||
| 				TopicAliasMaximum:          10000, | ||||
| 				MaximumClientWritesPending: 3, | ||||
| 				maximumPacketID:            10, | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| @@ -42,6 +47,9 @@ func newTestClient() (cl *Client, r net.Conn, w net.Conn) { | ||||
| 	cl.State.Inflight.receiveQuota = 10 | ||||
| 	cl.Properties.Props.TopicAliasMaximum = 0 | ||||
| 	cl.Properties.Props.RequestResponseInfo = 0x1 | ||||
|  | ||||
| 	go cl.WriteLoop() | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @@ -107,8 +115,8 @@ func TestClientsDelete(t *testing.T) { | ||||
|  | ||||
| func TestClientsGetByListener(t *testing.T) { | ||||
| 	cl := NewClients() | ||||
| 	cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}}) | ||||
| 	cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}}) | ||||
| 	cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}}) | ||||
| 	cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}}) | ||||
| 	require.Contains(t, cl.internal, "t1") | ||||
| 	require.Contains(t, cl.internal, "t2") | ||||
|  | ||||
| @@ -125,10 +133,12 @@ func TestNewClient(t *testing.T) { | ||||
| 	require.NotNil(t, cl.State.Inflight.internal) | ||||
| 	require.NotNil(t, cl.State.Subscriptions) | ||||
| 	require.NotNil(t, cl.State.TopicAliases) | ||||
| 	require.Equal(t, defaultKeepalive, cl.State.keepalive) | ||||
| 	require.Equal(t, defaultKeepalive, cl.State.Keepalive) | ||||
| 	require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion) | ||||
| 	require.NotNil(t, cl.Net.conn) | ||||
| 	require.NotNil(t, cl.Net.Conn) | ||||
| 	require.NotNil(t, cl.Net.bconn) | ||||
| 	require.NotNil(t, cl.ops) | ||||
| 	require.NotNil(t, cl.ops.options.Capabilities) | ||||
| 	require.False(t, cl.Net.Inline) | ||||
| } | ||||
|  | ||||
| @@ -155,7 +165,7 @@ func TestClientParseConnect(t *testing.T) { | ||||
|  | ||||
| 	cl.ParseConnect("tcp1", pk) | ||||
| 	require.Equal(t, pk.Connect.ClientIdentifier, cl.ID) | ||||
| 	require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive) | ||||
| 	require.Equal(t, pk.Connect.Keepalive, cl.State.Keepalive) | ||||
| 	require.Equal(t, pk.Connect.Clean, cl.Properties.Clean) | ||||
| 	require.Equal(t, pk.Connect.ClientIdentifier, cl.ID) | ||||
| 	require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName) | ||||
| @@ -163,8 +173,8 @@ func TestClientParseConnect(t *testing.T) { | ||||
| 	require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos) | ||||
| 	require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain) | ||||
| 	require.Equal(t, uint32(1), cl.Properties.Will.Flag) | ||||
| 	require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota) | ||||
| 	require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota) | ||||
| 	require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota) | ||||
| 	require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota) | ||||
| 	require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota) | ||||
| 	require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota) | ||||
| } | ||||
| @@ -237,28 +247,32 @@ func TestClientNextPacketIDInUse(t *testing.T) { | ||||
|  | ||||
| func TestClientNextPacketIDExhausted(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	for i := 0; i <= 65535; i++ { | ||||
| 		cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) | ||||
| 	for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ { | ||||
| 		cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)} | ||||
| 	} | ||||
|  | ||||
| 	i, err := cl.NextPacketID() | ||||
| 	require.Equal(t, uint32(0), i) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, packets.ErrQuotaExceeded) | ||||
| 	require.Equal(t, uint32(0), i) | ||||
| } | ||||
|  | ||||
| func TestClientNextPacketIDOverflow(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ { | ||||
| 		cl.State.Inflight.internal[uint16(i)] = packets.Packet{} | ||||
| 	} | ||||
|  | ||||
| 	cl.State.packetID = uint32(65534) | ||||
|  | ||||
| 	cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1) | ||||
| 	i, err := cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(65535), i) | ||||
| 	require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i) | ||||
| 	cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{} | ||||
|  | ||||
| 	i, err = cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(1), i) | ||||
| 	cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID | ||||
| 	_, err = cl.NextPacketID() | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, packets.ErrQuotaExceeded) | ||||
| } | ||||
|  | ||||
| func TestClientClearInflights(t *testing.T) { | ||||
| @@ -272,7 +286,9 @@ func TestClientClearInflights(t *testing.T) { | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n}) | ||||
| 	require.Equal(t, 5, cl.State.Inflight.Len()) | ||||
|  | ||||
| 	cl.ClearInflights(n, 4) | ||||
| 	deleted := cl.ClearInflights(n, 4) | ||||
| 	require.Len(t, deleted, 3) | ||||
| 	require.ElementsMatch(t, []uint16{1, 2, 5}, deleted) | ||||
| 	require.Equal(t, 2, cl.State.Inflight.Len()) | ||||
| } | ||||
|  | ||||
| @@ -318,7 +334,7 @@ func TestClientResendInflightMessagesNoMessages(t *testing.T) { | ||||
| func TestClientRefreshDeadline(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.refreshDeadline(10) | ||||
| 	require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline? | ||||
| 	require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline? | ||||
| } | ||||
|  | ||||
| func TestClientReadFixedHeader(t *testing.T) { | ||||
| @@ -352,7 +368,7 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) { | ||||
|  | ||||
| func TestClientReadFixedHeaderPacketOversized(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	cl.ops.capabilities.MaximumPacketSize = 2 | ||||
| 	cl.ops.options.Capabilities.MaximumPacketSize = 2 | ||||
| 	defer cl.Stop(errClientStop) | ||||
|  | ||||
| 	go func() { | ||||
| @@ -451,7 +467,7 @@ func TestClientReadOK(t *testing.T) { | ||||
| func TestClientReadDone(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	cl.State.done = 1 | ||||
| 	cl.State.cancelOpen() | ||||
|  | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| @@ -468,10 +484,17 @@ func TestClientStop(t *testing.T) { | ||||
| 	cl.Stop(nil) | ||||
| 	require.Equal(t, nil, cl.State.stopCause.Load()) | ||||
| 	require.Equal(t, time.Now().Unix(), cl.State.disconnected) | ||||
| 	require.Equal(t, uint32(1), cl.State.done) | ||||
| 	require.True(t, cl.Closed()) | ||||
| 	require.Equal(t, nil, cl.StopCause()) | ||||
| } | ||||
|  | ||||
| func TestClientClosed(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	require.False(t, cl.Closed()) | ||||
| 	cl.Stop(nil) | ||||
| 	require.True(t, cl.Closed()) | ||||
| } | ||||
|  | ||||
| func TestClientReadFixedHeaderError(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| @@ -577,7 +600,7 @@ func TestClientReadPacket(t *testing.T) { | ||||
|  | ||||
| func TestClientReadPacketInvalidTypeError(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.Net.conn.Close() | ||||
| 	cl.Net.Conn.Close() | ||||
| 	_, err := cl.ReadPacket(&packets.FixedHeader{}) | ||||
| 	require.Error(t, err) | ||||
| 	require.Contains(t, err.Error(), "invalid packet type") | ||||
| @@ -601,7 +624,7 @@ func TestClientWritePacket(t *testing.T) { | ||||
| 		require.NoError(t, err, pkInfo, tt.Case, tt.Desc) | ||||
|  | ||||
| 		time.Sleep(2 * time.Millisecond) | ||||
| 		cl.Net.conn.Close() | ||||
| 		cl.Net.Conn.Close() | ||||
|  | ||||
| 		require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc) | ||||
|  | ||||
| @@ -683,7 +706,7 @@ func TestClientWritePacketWriteNoConn(t *testing.T) { | ||||
|  | ||||
| func TestClientWritePacketWriteError(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.Net.conn.Close() | ||||
| 	cl.Net.Conn.Close() | ||||
|  | ||||
| 	err := cl.WritePacket(*pkTable[1].Packet) | ||||
| 	require.Error(t, err) | ||||
|   | ||||
							
								
								
									
										52
									
								
								examples/benchmark/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								examples/benchmark/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"flag" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener") | ||||
| 	flag.Parse() | ||||
|  | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", *tcpAddr, nil) | ||||
| 	err := server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
| @@ -30,14 +30,14 @@ func main() { | ||||
| 	l := server.Log.Level(zerolog.DebugLevel) | ||||
| 	server.Log = &l | ||||
|  | ||||
| 	err := server.AddHook(new(auth.AllowHook), nil) | ||||
| 	err := server.AddHook(new(debug.Hook), &debug.Options{ | ||||
| 		// ShowPacketData: true, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	err = server.AddHook(new(debug.Hook), &debug.Options{ | ||||
| 		ShowPacketData: true, | ||||
| 	}) | ||||
| 	err = server.AddHook(new(auth.AllowHook), nil) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|   | ||||
| @@ -110,8 +110,9 @@ func (h *ExampleHook) Init(config any) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) { | ||||
| func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error { | ||||
| 	h.Log.Info().Str("client", cl.ID).Msgf("client connected") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) { | ||||
|   | ||||
| @@ -26,10 +26,8 @@ func main() { | ||||
| 	}() | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	server.Options.Capabilities.ServerKeepAlive = 60 | ||||
| 	server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true | ||||
| 	server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true | ||||
| 	server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true | ||||
|  | ||||
| 	_ = server.AddHook(new(pahoAuthHook), nil) | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| @@ -62,6 +60,7 @@ func (h *pahoAuthHook) ID() string { | ||||
| func (h *pahoAuthHook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{ | ||||
| 		mqtt.OnConnectAuthenticate, | ||||
| 		mqtt.OnConnect, | ||||
| 		mqtt.OnACLCheck, | ||||
| 	}, []byte{b}) | ||||
| } | ||||
| @@ -73,3 +72,12 @@ func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) | ||||
| func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { | ||||
| 	return topic != "test/nosubscribe" | ||||
| } | ||||
|  | ||||
| func (h *pahoAuthHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error { | ||||
| 	// Handle paho test_server_keep_alive | ||||
| 	if pk.Connect.Keepalive == 120 && pk.Connect.Clean { | ||||
| 		cl.State.Keepalive = 60 | ||||
| 		cl.State.ServerKeepalive = true | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -30,12 +30,15 @@ func main() { | ||||
| 	server := mqtt.New(nil) | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
| 	err := server.AddHook(new(bolt.Hook), bolt.Options{ | ||||
| 	err := server.AddHook(new(bolt.Hook), &bolt.Options{ | ||||
| 		Path: "bolt.db", | ||||
| 		Options: &bbolt.Options{ | ||||
| 			Timeout: 500 * time.Millisecond, | ||||
| 		}, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err = server.AddListener(tcp) | ||||
|   | ||||
| @@ -97,7 +97,7 @@ func main() { | ||||
|  | ||||
| 	stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{ | ||||
| 		TLSConfig: tlsConfig, | ||||
| 	}, nil) | ||||
| 	}, server.Info) | ||||
| 	err = server.AddListener(stats) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
|   | ||||
							
								
								
									
										101
									
								
								fanpool.go
									
									
									
									
									
								
							
							
						
						
									
										101
									
								
								fanpool.go
									
									
									
									
									
								
							| @@ -1,101 +0,0 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co, chowyu08, muXxer | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
|  | ||||
| 	xh "github.com/cespare/xxhash/v2" | ||||
| ) | ||||
|  | ||||
| // taskChan is a channel for incoming task functions. | ||||
| type taskChan chan func() | ||||
|  | ||||
| // FanPool is a fixed-sized fan-style worker pool with multiple | ||||
| // working 'columns'. Instead of a single queue channel processed by | ||||
| // many goroutines, this fan pool uses many queue channels each | ||||
| // processed by a single goroutine. | ||||
| // Very special thanks are given to the authors of HMQ in particular | ||||
| // @chowyu08 and @muXxer for their work on the fixpool worker pool | ||||
| // https://github.com/fhmq/hmq/blob/master/pool/fixpool.go | ||||
| // from which this fan-pool is heavily inspired. | ||||
| type FanPool struct { | ||||
| 	queue    []taskChan | ||||
| 	wg       sync.WaitGroup | ||||
| 	capacity uint64 | ||||
| 	perChan  uint64 | ||||
| 	Mutex    sync.Mutex | ||||
| } | ||||
|  | ||||
| // New returns a new instance of FanPool. fanSize controls the number of 'columns' | ||||
| // of the fan, whereas queueSize controls the size of each column's queue. | ||||
| func NewFanPool(fanSize, queueSize uint64) *FanPool { | ||||
| 	pool := &FanPool{ | ||||
| 		capacity: fanSize, | ||||
| 		perChan:  queueSize, | ||||
| 		queue:    make([]taskChan, fanSize), | ||||
| 	} | ||||
|  | ||||
| 	pool.fillWorkers(fanSize) | ||||
|  | ||||
| 	return pool | ||||
| } | ||||
|  | ||||
| // fillWorkers adds columns to the fan pool with an associated worker goroutine. | ||||
| func (p *FanPool) fillWorkers(n uint64) { | ||||
| 	for i := uint64(0); i < n; i++ { | ||||
| 		p.queue[i] = make(taskChan, p.perChan) | ||||
| 		go p.worker(p.queue[i]) | ||||
| 		p.wg.Add(1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // worker is a worker goroutine which processes tasks from a single queue. | ||||
| func (p *FanPool) worker(ch taskChan) { | ||||
| 	defer p.wg.Done() | ||||
| 	var task func() | ||||
| 	var ok bool | ||||
| 	for { | ||||
| 		task, ok = <-ch | ||||
| 		if !ok { | ||||
| 			return | ||||
| 		} | ||||
| 		task() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Enqueue adds a new task to the queue to be processed. | ||||
| func (p *FanPool) Enqueue(id string, task func()) { | ||||
| 	if p.Size() == 0 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// We can use xh.Sum64 to get a specific queue index | ||||
| 	// which remains the same for a client id, giving each | ||||
| 	// client their own queue. | ||||
| 	p.queue[xh.Sum64([]byte(id))%p.Size()] <- task | ||||
| } | ||||
|  | ||||
| // Wait blocks until all the workers in the pool have completed. | ||||
| func (p *FanPool) Wait() { | ||||
| 	p.wg.Wait() | ||||
| } | ||||
|  | ||||
| // Close issues a shutdown signal to the workers. | ||||
| func (p *FanPool) Close() { | ||||
| 	for i := 0; i < int(p.Size()); i++ { | ||||
| 		if p.queue[i] != nil { | ||||
| 			close(p.queue[i]) | ||||
| 		} | ||||
| 	} | ||||
| 	p.queue = nil | ||||
| 	atomic.StoreUint64(&p.capacity, 0) | ||||
| } | ||||
|  | ||||
| // Size returns the current number of workers in the pool. | ||||
| func (p *FanPool) Size() uint64 { | ||||
| 	return atomic.LoadUint64(&p.capacity) | ||||
| } | ||||
| @@ -1,89 +0,0 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestFanPool(t *testing.T) { | ||||
| 	f := NewFanPool(1, 2) | ||||
| 	require.NotNil(t, f) | ||||
| 	require.Equal(t, uint64(1), f.capacity) | ||||
| 	require.Equal(t, 2, cap(f.queue[0])) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func() { | ||||
| 		f.Enqueue("test", func() { | ||||
| 			o <- true | ||||
| 		}) | ||||
| 	}() | ||||
|  | ||||
| 	require.True(t, <-o) | ||||
| 	f.Close() | ||||
| 	f.Wait() | ||||
| } | ||||
|  | ||||
| func TestFillWorkers(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		perChan: 3, | ||||
| 		queue:   make([]taskChan, 2), | ||||
| 	} | ||||
| 	f.fillWorkers(2) | ||||
| 	require.Len(t, f.queue, 2) | ||||
| 	require.Equal(t, 3, cap(f.queue[0])) | ||||
| } | ||||
|  | ||||
| func TestEnqueue(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		capacity: 2, | ||||
| 		queue: []taskChan{ | ||||
| 			make(taskChan, 2), | ||||
| 			make(taskChan, 2), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		f.Enqueue("a", func() {}) | ||||
| 	}() | ||||
| 	require.NotNil(t, <-f.queue[1]) | ||||
| } | ||||
|  | ||||
| func TestEnqueueOnEmpty(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		queue: []taskChan{}, | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		f.Enqueue("a", func() {}) | ||||
| 	}() | ||||
|  | ||||
| 	require.Len(t, f.queue, 0) | ||||
| } | ||||
|  | ||||
| func TestSize(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		capacity: 10, | ||||
| 	} | ||||
|  | ||||
| 	require.Equal(t, uint64(10), f.Size()) | ||||
| } | ||||
|  | ||||
| func TestClose(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		capacity: 3, | ||||
| 		queue: []taskChan{ | ||||
| 			make(taskChan, 2), | ||||
| 			make(taskChan, 2), | ||||
| 			make(taskChan, 2), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	f.Close() | ||||
| 	require.Equal(t, uint64(0), f.Size()) | ||||
| 	require.Nil(t, f.queue) | ||||
| } | ||||
							
								
								
									
										6
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								go.mod
									
									
									
									
									
								
							| @@ -6,7 +6,6 @@ require ( | ||||
| 	github.com/alicebob/miniredis/v2 v2.23.0 | ||||
| 	github.com/asdine/storm v2.1.2+incompatible | ||||
| 	github.com/asdine/storm/v3 v3.2.1 | ||||
| 	github.com/cespare/xxhash/v2 v2.1.2 | ||||
| 	github.com/go-redis/redis/v8 v8.11.5 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/jinzhu/copier v0.3.5 | ||||
| @@ -21,6 +20,7 @@ require ( | ||||
| require ( | ||||
| 	github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect | ||||
| 	github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect | ||||
| 	github.com/cespare/xxhash/v2 v2.1.2 // indirect | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/dgraph-io/badger v1.6.0 // indirect | ||||
| 	github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect | ||||
| @@ -33,8 +33,8 @@ require ( | ||||
| 	github.com/pkg/errors v0.9.1 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect | ||||
| 	golang.org/x/net v0.0.0-20220927171203-f486391704dc // indirect | ||||
| 	golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect | ||||
| 	golang.org/x/net v0.7.0 // indirect | ||||
| 	golang.org/x/sys v0.5.0 // indirect | ||||
| 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect | ||||
| 	google.golang.org/protobuf v1.28.1 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										10
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.sum
									
									
									
									
									
								
							| @@ -109,8 +109,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk | ||||
| golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= | ||||
| golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||||
| golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||||
| golang.org/x/net v0.0.0-20220927171203-f486391704dc h1:FxpXZdoBqT8RjqTy6i1E8nXHhW21wK7ptQ/EPIGxzPQ= | ||||
| golang.org/x/net v0.0.0-20220927171203-f486391704dc/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= | ||||
| golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= | ||||
| golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= | ||||
| golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| @@ -118,11 +118,11 @@ golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7w | ||||
| golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= | ||||
| golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= | ||||
| golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||
| golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= | ||||
| golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= | ||||
|   | ||||
							
								
								
									
										192
									
								
								hooks.go
									
									
									
									
									
								
							
							
						
						
									
										192
									
								
								hooks.go
									
									
									
									
									
								
							| @@ -1,6 +1,6 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
| // SPDX-FileContributor: mochi-co, thedevop | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| @@ -39,15 +39,17 @@ const ( | ||||
| 	OnUnsubscribed | ||||
| 	OnPublish | ||||
| 	OnPublished | ||||
| 	OnPublishDropped | ||||
| 	OnRetainMessage | ||||
| 	OnRetainPublished | ||||
| 	OnQosPublish | ||||
| 	OnQosComplete | ||||
| 	OnQosDropped | ||||
| 	OnPacketIDExhausted | ||||
| 	OnWill | ||||
| 	OnWillSent | ||||
| 	OnClientExpired | ||||
| 	OnRetainedExpired | ||||
| 	OnExpireInflights | ||||
| 	StoredClients | ||||
| 	StoredSubscriptions | ||||
| 	StoredInflightMessages | ||||
| @@ -73,7 +75,7 @@ type Hook interface { | ||||
| 	OnConnectAuthenticate(cl *Client, pk packets.Packet) bool | ||||
| 	OnACLCheck(cl *Client, topic string, write bool) bool | ||||
| 	OnSysInfoTick(*system.Info) | ||||
| 	OnConnect(cl *Client, pk packets.Packet) | ||||
| 	OnConnect(cl *Client, pk packets.Packet) error | ||||
| 	OnSessionEstablished(cl *Client, pk packets.Packet) | ||||
| 	OnDisconnect(cl *Client, err error, expire bool) | ||||
| 	OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) | ||||
| @@ -88,15 +90,17 @@ type Hook interface { | ||||
| 	OnUnsubscribed(cl *Client, pk packets.Packet) | ||||
| 	OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) | ||||
| 	OnPublished(cl *Client, pk packets.Packet) | ||||
| 	OnPublishDropped(cl *Client, pk packets.Packet) | ||||
| 	OnRetainMessage(cl *Client, pk packets.Packet, r int64) | ||||
| 	OnRetainPublished(cl *Client, pk packets.Packet) | ||||
| 	OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) | ||||
| 	OnQosComplete(cl *Client, pk packets.Packet) | ||||
| 	OnQosDropped(cl *Client, pk packets.Packet) | ||||
| 	OnPacketIDExhausted(cl *Client, pk packets.Packet) | ||||
| 	OnWill(cl *Client, will Will) (Will, error) | ||||
| 	OnWillSent(cl *Client, pk packets.Packet) | ||||
| 	OnClientExpired(cl *Client) | ||||
| 	OnRetainedExpired(filter string) | ||||
| 	OnExpireInflights(cl *Client, expiry int64) | ||||
| 	StoredClients() ([]storage.Client, error) | ||||
| 	StoredSubscriptions() ([]storage.Subscription, error) | ||||
| 	StoredInflightMessages() ([]storage.Message, error) | ||||
| @@ -112,10 +116,10 @@ type HookOptions struct { | ||||
| // Hooks is a slice of Hook interfaces to be called in sequence. | ||||
| type Hooks struct { | ||||
| 	Log        *zerolog.Logger // a logger for the hook (from the server) | ||||
| 	internal   []Hook          // a slice of hooks | ||||
| 	internal   atomic.Value    // a slice of []Hook | ||||
| 	wg         sync.WaitGroup  // a waitgroup for syncing hook shutdown | ||||
| 	qty        int64           // the number of hooks in use | ||||
| 	sync.Mutex                 // a mutex | ||||
| 	sync.Mutex                 // a mutex for locking when adding hooks | ||||
| } | ||||
|  | ||||
| // Len returns the number of hooks added. | ||||
| @@ -125,7 +129,7 @@ func (h *Hooks) Len() int64 { | ||||
|  | ||||
| // Provides returns true if any one hook provides any of the requested hook methods. | ||||
| func (h *Hooks) Provides(b ...byte) bool { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		for _, hb := range b { | ||||
| 			if hook.Provides(hb) { | ||||
| 				return true | ||||
| @@ -140,26 +144,39 @@ func (h *Hooks) Provides(b ...byte) bool { | ||||
| func (h *Hooks) Add(hook Hook, config any) error { | ||||
| 	h.Lock() | ||||
| 	defer h.Unlock() | ||||
| 	if h.internal == nil { | ||||
| 		h.internal = []Hook{} | ||||
| 	} | ||||
|  | ||||
| 	err := hook.Init(config) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err) | ||||
| 	} | ||||
|  | ||||
| 	h.internal = append(h.internal, hook) | ||||
| 	i, ok := h.internal.Load().([]Hook) | ||||
| 	if !ok { | ||||
| 		i = []Hook{} | ||||
| 	} | ||||
|  | ||||
| 	i = append(i, hook) | ||||
| 	h.internal.Store(i) | ||||
| 	atomic.AddInt64(&h.qty, 1) | ||||
| 	h.wg.Add(1) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // GetAll returns a slice of all the hooks. | ||||
| func (h *Hooks) GetAll() []Hook { | ||||
| 	i, ok := h.internal.Load().([]Hook) | ||||
| 	if !ok { | ||||
| 		return []Hook{} | ||||
| 	} | ||||
|  | ||||
| 	return i | ||||
| } | ||||
|  | ||||
| // Stop indicates all attached hooks to gracefully end. | ||||
| func (h *Hooks) Stop() { | ||||
| 	go func() { | ||||
| 		for _, hook := range h.internal { | ||||
| 		for _, hook := range h.GetAll() { | ||||
| 			h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook") | ||||
| 			if err := hook.Stop(); err != nil { | ||||
| 				h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook") | ||||
| @@ -174,7 +191,7 @@ func (h *Hooks) Stop() { | ||||
|  | ||||
| // OnSysInfoTick is called when the $SYS topic values are published out. | ||||
| func (h *Hooks) OnSysInfoTick(sys *system.Info) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnSysInfoTick) { | ||||
| 			hook.OnSysInfoTick(sys) | ||||
| 		} | ||||
| @@ -183,7 +200,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) { | ||||
|  | ||||
| // OnStarted is called when the server has successfully started. | ||||
| func (h *Hooks) OnStarted() { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnStarted) { | ||||
| 			hook.OnStarted() | ||||
| 		} | ||||
| @@ -192,25 +209,29 @@ func (h *Hooks) OnStarted() { | ||||
|  | ||||
| // OnStopped is called when the server has successfully stopped. | ||||
| func (h *Hooks) OnStopped() { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnStopped) { | ||||
| 			hook.OnStopped() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnConnect is called when a new client connects. | ||||
| func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| // OnConnect is called when a new client connects, and may return a packets.Code as an error to halt the connection. | ||||
| func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) error { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnConnect) { | ||||
| 			hook.OnConnect(cl, pk) | ||||
| 			err := hook.OnConnect(cl, pk) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // OnSessionEstablished is called when a new client establishes a session (after OnConnect). | ||||
| func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnSessionEstablished) { | ||||
| 			hook.OnSessionEstablished(cl, pk) | ||||
| 		} | ||||
| @@ -219,7 +240,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) { | ||||
|  | ||||
| // OnDisconnect is called when a client is disconnected for any reason. | ||||
| func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnDisconnect) { | ||||
| 			hook.OnDisconnect(cl, err, expire) | ||||
| 		} | ||||
| @@ -229,7 +250,7 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) { | ||||
| // OnPacketRead is called when a packet is received from a client. | ||||
| func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||
| 	pkx = pk | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnPacketRead) { | ||||
| 			npk, err := hook.OnPacketRead(cl, pkx) | ||||
| 			if err != nil && errors.Is(err, packets.ErrRejectPacket) { | ||||
| @@ -250,7 +271,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, | ||||
| // to create their own auth packet handling mechanisms. | ||||
| func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||
| 	pkx = pk | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnAuthPacket) { | ||||
| 			npk, err := hook.OnAuthPacket(cl, pkx) | ||||
| 			if err != nil { | ||||
| @@ -266,7 +287,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, | ||||
|  | ||||
| // OnPacketEncode is called immediately before a packet is encoded to be sent to a client. | ||||
| func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnPacketEncode) { | ||||
| 			pk = hook.OnPacketEncode(cl, pk) | ||||
| 		} | ||||
| @@ -277,7 +298,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { | ||||
|  | ||||
| // OnPacketProcessed is called when a packet has been received and successfully handled by the broker. | ||||
| func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnPacketProcessed) { | ||||
| 			hook.OnPacketProcessed(cl, pk, err) | ||||
| 		} | ||||
| @@ -287,7 +308,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) { | ||||
| // OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter | ||||
| // containing the bytes sent. | ||||
| func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnPacketSent) { | ||||
| 			hook.OnPacketSent(cl, pk, b) | ||||
| 		} | ||||
| @@ -299,7 +320,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) { | ||||
| // before the packet is processed. The return values of the hook methods are passed-through | ||||
| // in the order the hooks were attached. | ||||
| func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnSubscribe) { | ||||
| 			pk = hook.OnSubscribe(cl, pk) | ||||
| 		} | ||||
| @@ -309,7 +330,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||
|  | ||||
| // OnSubscribed is called when a client subscribes to one or more filters. | ||||
| func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnSubscribed) { | ||||
| 			hook.OnSubscribed(cl, pk, reasonCodes) | ||||
| 		} | ||||
| @@ -321,7 +342,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) | ||||
| // remove or add clients to a publish to subscribers process, or to select the subscriber for a shared | ||||
| // group in a custom manner (such as based on client id, ip, etc). | ||||
| func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnSelectSubscribers) { | ||||
| 			subs = hook.OnSelectSubscribers(subs, pk) | ||||
| 		} | ||||
| @@ -334,7 +355,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc | ||||
| // before the packet is processed. The return values of the hook methods are passed-through | ||||
| // in the order the hooks were attached. | ||||
| func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnUnsubscribe) { | ||||
| 			pk = hook.OnUnsubscribe(cl, pk) | ||||
| 		} | ||||
| @@ -344,7 +365,7 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||
|  | ||||
| // OnUnsubscribed is called when a client unsubscribes from one or more filters. | ||||
| func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnUnsubscribed) { | ||||
| 			hook.OnUnsubscribed(cl, pk) | ||||
| 		} | ||||
| @@ -356,16 +377,17 @@ func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) { | ||||
| // The return values of the hook methods are passed-through in the order the hooks were attached. | ||||
| func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||
| 	pkx = pk | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnPublish) { | ||||
| 			npk, err := hook.OnPublish(cl, pkx) | ||||
| 			if err != nil && errors.Is(err, packets.ErrRejectPacket) { | ||||
| 				h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected") | ||||
| 			if err != nil { | ||||
| 				if errors.Is(err, packets.ErrRejectPacket) { | ||||
| 					h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected") | ||||
| 					return pk, err | ||||
| 				} | ||||
| 				h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet error") | ||||
| 				return pk, err | ||||
| 			} else if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			pkx = npk | ||||
| 		} | ||||
| 	} | ||||
| @@ -375,27 +397,46 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er | ||||
|  | ||||
| // OnPublished is called when a client has published a message to subscribers. | ||||
| func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnPublished) { | ||||
| 			hook.OnPublished(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnPublishDropped is called when a message to a client was dropped instead of delivered | ||||
| // such as when a client is too slow to respond. | ||||
| func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnPublishDropped) { | ||||
| 			hook.OnPublishDropped(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainMessage is called then a published message is retained. | ||||
| func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnRetainMessage) { | ||||
| 			hook.OnRetainMessage(cl, pk, r) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainPublished is called when a retained message is published. | ||||
| func (h *Hooks) OnRetainPublished(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnRetainPublished) { | ||||
| 			hook.OnRetainPublished(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber. | ||||
| // In other words, this method is called when a new inflight message is created or resent. | ||||
| // It is typically used to store a new inflight message. | ||||
| func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnQosPublish) { | ||||
| 			hook.OnQosPublish(cl, pk, sent, resends) | ||||
| 		} | ||||
| @@ -406,7 +447,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends | ||||
| // In other words, when an inflight message is resolved. | ||||
| // It is typically used to delete an inflight message from a store. | ||||
| func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnQosComplete) { | ||||
| 			hook.OnQosComplete(cl, pk) | ||||
| 		} | ||||
| @@ -414,22 +455,32 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) { | ||||
| } | ||||
|  | ||||
| // OnQosDropped is called the Qos flow for a message expires. In other words, when | ||||
| // an inflight message expires or is abandoned. | ||||
| // It is typically used to delete an inflight message from a store. | ||||
| // an inflight message expires or is abandoned. It is typically used to delete an | ||||
| // inflight message from a store. | ||||
| func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnQosDropped) { | ||||
| 			hook.OnQosDropped(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnPacketIDExhausted is called when the client runs out of unused packet ids to | ||||
| // assign to a packet. | ||||
| func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnPacketIDExhausted) { | ||||
| 			hook.OnPacketIDExhausted(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnWill is called when a client disconnects and publishes an LWT message. This method | ||||
| // differs from OnWillSent in that it allows you to modify the LWT message before it is | ||||
| // published. The return values of the hook methods are passed-through in the order | ||||
| // the hooks were attached. | ||||
| func (h *Hooks) OnWill(cl *Client, will Will) Will { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnWill) { | ||||
| 			mlwt, err := hook.OnWill(cl, will) | ||||
| 			if err != nil { | ||||
| @@ -445,7 +496,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will { | ||||
|  | ||||
| // OnWillSent is called when an LWT message has been issued from a disconnecting client. | ||||
| func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnWillSent) { | ||||
| 			hook.OnWillSent(cl, pk) | ||||
| 		} | ||||
| @@ -454,7 +505,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) { | ||||
|  | ||||
| // OnClientExpired is called when a client session has expired and should be deleted. | ||||
| func (h *Hooks) OnClientExpired(cl *Client) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnClientExpired) { | ||||
| 			hook.OnClientExpired(cl) | ||||
| 		} | ||||
| @@ -463,7 +514,7 @@ func (h *Hooks) OnClientExpired(cl *Client) { | ||||
|  | ||||
| // OnRetainedExpired is called when a retained message has expired and should be deleted. | ||||
| func (h *Hooks) OnRetainedExpired(filter string) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnRetainedExpired) { | ||||
| 			hook.OnRetainedExpired(filter) | ||||
| 		} | ||||
| @@ -473,7 +524,7 @@ func (h *Hooks) OnRetainedExpired(filter string) { | ||||
| // StoredClients returns all clients, e.g. from a persistent store, is used to | ||||
| // populate the server clients list before start. | ||||
| func (h *Hooks) StoredClients() (v []storage.Client, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(StoredClients) { | ||||
| 			v, err := hook.StoredClients() | ||||
| 			if err != nil { | ||||
| @@ -493,7 +544,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) { | ||||
| // StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is | ||||
| // used to populate the server subscriptions list before start. | ||||
| func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(StoredSubscriptions) { | ||||
| 			v, err := hook.StoredSubscriptions() | ||||
| 			if err != nil { | ||||
| @@ -513,7 +564,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| // StoredInflightMessages returns all inflight messages, e.g. from a persistent store, | ||||
| // and is used to populate the restored clients with inflight messages before start. | ||||
| func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(StoredInflightMessages) { | ||||
| 			v, err := hook.StoredInflightMessages() | ||||
| 			if err != nil { | ||||
| @@ -533,7 +584,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| // StoredRetainedMessages returns all retained messages, e.g. from a persistent store, | ||||
| // and is used to populate the server topics with retained messages before start. | ||||
| func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(StoredRetainedMessages) { | ||||
| 			v, err := hook.StoredRetainedMessages() | ||||
| 			if err != nil { | ||||
| @@ -552,7 +603,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
|  | ||||
| // StoredSysInfo returns a set of system info values. | ||||
| func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(StoredSysInfo) { | ||||
| 			v, err := hook.StoredSysInfo() | ||||
| 			if err != nil { | ||||
| @@ -574,7 +625,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| // server (see hooks/auth/allow_all or basic). It can be used in custom hooks to | ||||
| // check connecting users against an existing user database. | ||||
| func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnConnectAuthenticate) { | ||||
| 			if ok := hook.OnConnectAuthenticate(cl, pk); ok { | ||||
| 				return true | ||||
| @@ -590,7 +641,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | ||||
| // (see hooks/auth/allow_all or basic). It can be used in custom hooks to | ||||
| // check publishing and subscribing users against an existing permissions or roles database. | ||||
| func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { | ||||
| 	for _, hook := range h.internal { | ||||
| 	for _, hook := range h.GetAll() { | ||||
| 		if hook.Provides(OnACLCheck) { | ||||
| 			if ok := hook.OnACLCheck(cl, topic, write); ok { | ||||
| 				return true | ||||
| @@ -601,19 +652,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // OnExpireInflights is called when the server issues a clear request for expired | ||||
| // inflight messages. Expiry should be the time after which the message is no longer | ||||
| // valid (usually some time in the past). A message has expired if it's created time | ||||
| // is older than time.Now() minus Inflight TTL. This method can be used to expire | ||||
| // old inflight messages in a persistent store which doesnt support per-item TTL. | ||||
| func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnExpireInflights) { | ||||
| 			hook.OnExpireInflights(cl, expiry) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // HookBase provides a set of default methods for each hook. It should be embedded in | ||||
| // all hooks. | ||||
| type HookBase struct { | ||||
| @@ -646,7 +684,7 @@ func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) { | ||||
| 	h.Opts = opts | ||||
| } | ||||
|  | ||||
| // Stop is called to gracefully shutdown the hook. | ||||
| // Stop is called to gracefully shut down the hook. | ||||
| func (h *HookBase) Stop() error { | ||||
| 	return nil | ||||
| } | ||||
| @@ -671,7 +709,9 @@ func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool { | ||||
| } | ||||
|  | ||||
| // OnConnect is called when a new client connects. | ||||
| func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {} | ||||
| func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // OnSessionEstablished is called when a new client establishes a session (after OnConnect). | ||||
| func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {} | ||||
| @@ -729,9 +769,15 @@ func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, err | ||||
| // OnPublished is called when a client has published a message to subscribers. | ||||
| func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnPublishDropped is called when a message to a client is dropped instead of being delivered. | ||||
| func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnRetainMessage is called then a published message is retained. | ||||
| func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {} | ||||
|  | ||||
| // OnRetainPublished is called when a retained message is published. | ||||
| func (h *HookBase) OnRetainPublished(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber. | ||||
| func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {} | ||||
|  | ||||
| @@ -741,6 +787,9 @@ func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {} | ||||
| // OnQosDropped is called the Qos flow for a message expires. | ||||
| func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet. | ||||
| func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnWill is called when a client disconnects and publishes an LWT message. | ||||
| func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) { | ||||
| 	return will, nil | ||||
| @@ -755,9 +804,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {} | ||||
| // OnRetainedExpired is called when a retained message for a topic has expired. | ||||
| func (h *HookBase) OnRetainedExpired(topic string) {} | ||||
|  | ||||
| // OnExpireInflights is called when the server issues a clear request for expired inflight messages. | ||||
| func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {} | ||||
|  | ||||
| // StoredClients returns all clients from a store. | ||||
| func (h *HookBase) StoredClients() (v []storage.Client, err error) { | ||||
| 	return | ||||
|   | ||||
| @@ -80,7 +80,6 @@ func (h *Hook) Provides(b byte) bool { | ||||
| 		mqtt.OnSysInfoTick, | ||||
| 		mqtt.OnClientExpired, | ||||
| 		mqtt.OnRetainedExpired, | ||||
| 		mqtt.OnExpireInflights, | ||||
| 		mqtt.StoredClients, | ||||
| 		mqtt.StoredInflightMessages, | ||||
| 		mqtt.StoredRetainedMessages, | ||||
| @@ -183,6 +182,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if cl.StopCause() == packets.ErrSessionTakenOver { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Delete(clientKey(cl), new(storage.Client)) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data") | ||||
| @@ -199,11 +202,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by | ||||
| 	var in *storage.Subscription | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		in = &storage.Subscription{ | ||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:      storage.SubscriptionKey, | ||||
| 			Client: cl.ID, | ||||
| 			Filter: pk.Filters[i].Filter, | ||||
| 			Qos:    reasonCodes[i], | ||||
| 			ID:                subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:                 storage.SubscriptionKey, | ||||
| 			Client:            cl.ID, | ||||
| 			Qos:               reasonCodes[i], | ||||
| 			Filter:            pk.Filters[i].Filter, | ||||
| 			Identifier:        pk.Filters[i].Identifier, | ||||
| 			NoLocal:           pk.Filters[i].NoLocal, | ||||
| 			RetainHandling:    pk.Filters[i].RetainHandling, | ||||
| 			RetainAsPublished: pk.Filters[i].RetainAsPublished, | ||||
| 		} | ||||
|  | ||||
| 		err := h.db.Upsert(in.ID, in) | ||||
| @@ -348,32 +355,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnExpireInflights removes all inflight messages which have passed the provided expiry time. | ||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var v []storage.Message | ||||
| 	err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey)) | ||||
| 	if err != nil && !errors.Is(err, badgerhold.ErrNotFound) { | ||||
| 		h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, m := range v { | ||||
| 		if m.Created < expiry || m.Created == 0 { | ||||
| 			err := h.db.Delete(m.ID, new(storage.Message)) | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	err := h.db.Delete(retainedKey(filter), new(storage.Message)) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data") | ||||
| @@ -382,6 +370,11 @@ func (h *Hook) OnRetainedExpired(filter string) { | ||||
|  | ||||
| // OnClientExpired deleted expired clients from the store. | ||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Delete(clientKey(cl), new(storage.Client)) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data") | ||||
|   | ||||
| @@ -5,13 +5,11 @@ | ||||
| package badger | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/asdine/storm/v3" | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| @@ -170,6 +168,21 @@ func TestOnClientExpired(t *testing.T) { | ||||
| 	require.ErrorIs(t, badgerhold.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpiredNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnClientExpired(client) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpiredClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnClientExpired(client) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| @@ -219,6 +232,29 @@ func TestOnDisconnectClosedDB(t *testing.T) { | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectSessionTakenOver(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	testClient := &mqtt.Client{ | ||||
| 		ID: "test", | ||||
| 		Net: mqtt.ClientConnection{ | ||||
| 			Remote:   "test.addr", | ||||
| 			Listener: "listener", | ||||
| 		}, | ||||
| 		Properties: mqtt.ClientProperties{ | ||||
| 			Username: []byte("username"), | ||||
| 			Clean:    false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	testClient.Stop(packets.ErrSessionTakenOver) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnDisconnect(testClient, nil, true) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| @@ -333,6 +369,21 @@ func TestOnRetainedExpired(t *testing.T) { | ||||
| 	require.ErrorIs(t, err, badgerhold.ErrNotFound) | ||||
| } | ||||
|  | ||||
| func TestOnRetainExpiredNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnRetainedExpired("a/b/c") | ||||
| } | ||||
|  | ||||
| func TestOnRetainExpiredClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnRetainedExpired("a/b/c") | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| @@ -419,48 +470,6 @@ func TestOnQosDroppedNoDB(t *testing.T) { | ||||
| 	h.OnQosDropped(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflights(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1}) | ||||
| 	require.NoError(t, err) | ||||
| 	err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20}) | ||||
| 	require.NoError(t, err) | ||||
| 	err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
|  | ||||
| 	var v []storage.Message | ||||
| 	err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey)) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	require.Len(t, v, 1) | ||||
| 	require.Equal(t, "i1", v[0].ID) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTick(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|   | ||||
| @@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool { | ||||
| 		mqtt.OnSysInfoTick, | ||||
| 		mqtt.OnClientExpired, | ||||
| 		mqtt.OnRetainedExpired, | ||||
| 		mqtt.OnExpireInflights, | ||||
| 		mqtt.StoredClients, | ||||
| 		mqtt.StoredInflightMessages, | ||||
| 		mqtt.StoredRetainedMessages, | ||||
| @@ -185,6 +184,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if cl.StopCause() == packets.ErrSessionTakenOver { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") | ||||
| @@ -201,12 +204,17 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by | ||||
| 	var in *storage.Subscription | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		in = &storage.Subscription{ | ||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:      storage.SubscriptionKey, | ||||
| 			Client: cl.ID, | ||||
| 			Filter: pk.Filters[i].Filter, | ||||
| 			Qos:    reasonCodes[i], | ||||
| 			ID:                subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:                 storage.SubscriptionKey, | ||||
| 			Client:            cl.ID, | ||||
| 			Qos:               reasonCodes[i], | ||||
| 			Filter:            pk.Filters[i].Filter, | ||||
| 			Identifier:        pk.Filters[i].Identifier, | ||||
| 			NoLocal:           pk.Filters[i].NoLocal, | ||||
| 			RetainHandling:    pk.Filters[i].RetainHandling, | ||||
| 			RetainAsPublished: pk.Filters[i].RetainAsPublished, | ||||
| 		} | ||||
|  | ||||
| 		err := h.db.Save(in) | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err). | ||||
| @@ -369,34 +377,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnExpireInflights removes all inflight messages which have passed the | ||||
| // provided expiry time. | ||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var v []storage.Message | ||||
| 	err := h.db.Find("T", storage.InflightKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, m := range v { | ||||
| 		if m.Created < expiry || m.Created == 0 { | ||||
| 			err := h.db.DeleteStruct(&storage.Message{ID: m.ID}) | ||||
| 			if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 				h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data") | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish") | ||||
| 	} | ||||
| @@ -404,6 +391,11 @@ func (h *Hook) OnRetainedExpired(filter string) { | ||||
|  | ||||
| // OnClientExpired deleted expired clients from the store. | ||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") | ||||
|   | ||||
| @@ -5,7 +5,6 @@ | ||||
| package bolt | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| @@ -212,6 +211,21 @@ func TestOnClientExpired(t *testing.T) { | ||||
| 	require.ErrorIs(t, storm.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpiredClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnClientExpired(client) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpiredNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnClientExpired(client) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| @@ -227,6 +241,29 @@ func TestOnDisconnectClosedDB(t *testing.T) { | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectSessionTakenOver(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	testClient := &mqtt.Client{ | ||||
| 		ID: "test", | ||||
| 		Net: mqtt.ClientConnection{ | ||||
| 			Remote:   "test.addr", | ||||
| 			Listener: "listener", | ||||
| 		}, | ||||
| 		Properties: mqtt.ClientProperties{ | ||||
| 			Username: []byte("username"), | ||||
| 			Clean:    false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	testClient.Stop(packets.ErrSessionTakenOver) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnDisconnect(testClient, nil, true) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| @@ -341,6 +378,21 @@ func TestOnRetainedExpired(t *testing.T) { | ||||
| 	require.Equal(t, storm.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnRetainedExpiredClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnRetainedExpired("a/b/c") | ||||
| } | ||||
|  | ||||
| func TestOnRetainedExpiredNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnRetainedExpired("a/b/c") | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| @@ -427,48 +479,6 @@ func TestOnQosDroppedNoDB(t *testing.T) { | ||||
| 	h.OnQosDropped(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflights(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1}) | ||||
| 	require.NoError(t, err) | ||||
| 	err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20}) | ||||
| 	require.NoError(t, err) | ||||
| 	err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
|  | ||||
| 	var v []storage.Message | ||||
| 	err = h.db.Find("T", storage.InflightKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	require.Len(t, v, 1) | ||||
| 	require.Equal(t, "i1", v[0].ID) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTick(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|   | ||||
| @@ -83,7 +83,6 @@ func (h *Hook) Provides(b byte) bool { | ||||
| 		mqtt.OnSysInfoTick, | ||||
| 		mqtt.OnClientExpired, | ||||
| 		mqtt.OnRetainedExpired, | ||||
| 		mqtt.OnExpireInflights, | ||||
| 		mqtt.StoredClients, | ||||
| 		mqtt.StoredInflightMessages, | ||||
| 		mqtt.StoredRetainedMessages, | ||||
| @@ -200,6 +199,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if cl.StopCause() == packets.ErrSessionTakenOver { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") | ||||
| @@ -216,11 +219,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by | ||||
| 	var in *storage.Subscription | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		in = &storage.Subscription{ | ||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:      storage.SubscriptionKey, | ||||
| 			Client: cl.ID, | ||||
| 			Filter: pk.Filters[i].Filter, | ||||
| 			Qos:    reasonCodes[i], | ||||
| 			ID:                subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:                 storage.SubscriptionKey, | ||||
| 			Client:            cl.ID, | ||||
| 			Qos:               reasonCodes[i], | ||||
| 			Filter:            pk.Filters[i].Filter, | ||||
| 			Identifier:        pk.Filters[i].Identifier, | ||||
| 			NoLocal:           pk.Filters[i].NoLocal, | ||||
| 			RetainHandling:    pk.Filters[i].RetainHandling, | ||||
| 			RetainAsPublished: pk.Filters[i].RetainAsPublished, | ||||
| 		} | ||||
|  | ||||
| 		err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err() | ||||
| @@ -364,37 +371,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnExpireInflights removes all inflight messages which have passed the | ||||
| // provided expiry time. | ||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() | ||||
| 	if err != nil && !errors.Is(err, redis.Nil) { | ||||
| 		h.Log.Error().Err(err).Msg("failed to HGetAll inflight data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, row := range rows { | ||||
| 		var d storage.Message | ||||
| 		if err = d.UnmarshalBinary([]byte(row)); err != nil { | ||||
| 			h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data") | ||||
| 		} | ||||
|  | ||||
| 		if d.Created < expiry || d.Created == 0 { | ||||
| 			err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err() | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data") | ||||
| @@ -403,6 +386,11 @@ func (h *Hook) OnRetainedExpired(filter string) { | ||||
|  | ||||
| // OnClientExpired deleted expired clients from the store. | ||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") | ||||
|   | ||||
| @@ -253,6 +253,22 @@ func TestOnClientExpired(t *testing.T) { | ||||
| 	require.ErrorIs(t, redis.Nil, err) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpiredClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnClientExpired(client) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpiredNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnClientExpired(client) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| @@ -269,6 +285,28 @@ func TestOnDisconnectClosedDB(t *testing.T) { | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectSessionTakenOver(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
|  | ||||
| 	testClient := &mqtt.Client{ | ||||
| 		ID: "test", | ||||
| 		Net: mqtt.ClientConnection{ | ||||
| 			Remote:   "test.addr", | ||||
| 			Listener: "listener", | ||||
| 		}, | ||||
| 		Properties: mqtt.ClientProperties{ | ||||
| 			Username: []byte("username"), | ||||
| 			Clean:    false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	testClient.Stop(packets.ErrSessionTakenOver) | ||||
| 	teardown(t, h) | ||||
| 	h.OnDisconnect(testClient, nil, true) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| @@ -392,6 +430,22 @@ func TestOnRetainedExpired(t *testing.T) { | ||||
| 	require.ErrorIs(t, err, redis.Nil) | ||||
| } | ||||
|  | ||||
| func TestOnRetainedExpiredClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnRetainedExpired("a/b/c") | ||||
| } | ||||
|  | ||||
| func TestOnRetainedExpiredNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnRetainedExpired("a/b/c") | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| @@ -484,60 +538,6 @@ func TestOnQosDroppedNoDB(t *testing.T) { | ||||
| 	h.OnQosDropped(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflights(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	n := time.Now().Unix() | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", | ||||
| 		&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1}, | ||||
| 	).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", | ||||
| 		&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20}, | ||||
| 	).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", | ||||
| 		&storage.Message{ID: "i3", T: storage.InflightKey}, | ||||
| 	).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
|  | ||||
| 	var r []storage.Message | ||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, rows, 1) | ||||
| 	for _, row := range rows { | ||||
| 		var d storage.Message | ||||
| 		err = d.UnmarshalBinary([]byte(row)) | ||||
| 		require.NoError(t, err) | ||||
| 		r = append(r, d) | ||||
| 	} | ||||
| 	require.Len(t, r, 1) | ||||
| 	require.Equal(t, "i1", r[0].ID) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTick(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
|   | ||||
| @@ -117,6 +117,36 @@ func (d *Message) UnmarshalBinary(data []byte) error { | ||||
| 	return json.Unmarshal(data, d) | ||||
| } | ||||
|  | ||||
| // ToPacket converts a storage.Message to a standard packet. | ||||
| func (d *Message) ToPacket() packets.Packet { | ||||
| 	pk := packets.Packet{ | ||||
| 		FixedHeader: d.FixedHeader, | ||||
| 		PacketID:    d.PacketID, | ||||
| 		TopicName:   d.TopicName, | ||||
| 		Payload:     d.Payload, | ||||
| 		Origin:      d.Origin, | ||||
| 		Created:     d.Created, | ||||
| 		Properties: packets.Properties{ | ||||
| 			PayloadFormat:          d.Properties.PayloadFormat, | ||||
| 			PayloadFormatFlag:      d.Properties.PayloadFormatFlag, | ||||
| 			MessageExpiryInterval:  d.Properties.MessageExpiryInterval, | ||||
| 			ContentType:            d.Properties.ContentType, | ||||
| 			ResponseTopic:          d.Properties.ResponseTopic, | ||||
| 			CorrelationData:        d.Properties.CorrelationData, | ||||
| 			SubscriptionIdentifier: d.Properties.SubscriptionIdentifier, | ||||
| 			TopicAlias:             d.Properties.TopicAlias, | ||||
| 			User:                   d.Properties.User, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// Return a deep copy of the packet data otherwise the slices will | ||||
| 	// continue pointing at the values from the storage packet. | ||||
| 	pk = pk.Copy(true) | ||||
| 	pk.FixedHeader.Dup = d.FixedHeader.Dup | ||||
|  | ||||
| 	return pk | ||||
| } | ||||
|  | ||||
| // Subscription is a storable representation of an mqtt subscription. | ||||
| type Subscription struct { | ||||
| 	T                 string `json:"t"` | ||||
|   | ||||
| @@ -104,6 +104,7 @@ var ( | ||||
| 			ClientsMaximum:   7, | ||||
| 			MessagesReceived: 10, | ||||
| 			MessagesSent:     11, | ||||
| 			MessagesDropped:  20, | ||||
| 			PacketsReceived:  12, | ||||
| 			PacketsSent:      13, | ||||
| 			Retained:         15, | ||||
| @@ -111,7 +112,7 @@ var ( | ||||
| 			InflightDropped:  17, | ||||
| 		}, | ||||
| 	} | ||||
| 	sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`) | ||||
| 	sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`) | ||||
| ) | ||||
|  | ||||
| func TestClientMarshalBinary(t *testing.T) { | ||||
| @@ -193,3 +194,35 @@ func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) { | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, SystemInfo{}, d) | ||||
| } | ||||
|  | ||||
| func TestMessageToPacket(t *testing.T) { | ||||
| 	d := messageStruct | ||||
| 	pk := d.ToPacket() | ||||
|  | ||||
| 	require.Equal(t, packets.Packet{ | ||||
| 		Payload: []byte("payload"), | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Remaining: d.FixedHeader.Remaining, | ||||
| 			Type:      d.FixedHeader.Type, | ||||
| 			Qos:       d.FixedHeader.Qos, | ||||
| 			Dup:       d.FixedHeader.Dup, | ||||
| 			Retain:    d.FixedHeader.Retain, | ||||
| 		}, | ||||
| 		Origin:    d.Origin, | ||||
| 		TopicName: d.TopicName, | ||||
| 		Properties: packets.Properties{ | ||||
| 			PayloadFormat:          d.Properties.PayloadFormat, | ||||
| 			PayloadFormatFlag:      d.Properties.PayloadFormatFlag, | ||||
| 			MessageExpiryInterval:  d.Properties.MessageExpiryInterval, | ||||
| 			ContentType:            d.Properties.ContentType, | ||||
| 			ResponseTopic:          d.Properties.ResponseTopic, | ||||
| 			CorrelationData:        d.Properties.CorrelationData, | ||||
| 			SubscriptionIdentifier: d.Properties.SubscriptionIdentifier, | ||||
| 			TopicAlias:             d.Properties.TopicAlias, | ||||
| 			User:                   d.Properties.User, | ||||
| 		}, | ||||
| 		PacketID: 100, | ||||
| 		Created:  d.Created, | ||||
| 	}, pk) | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -27,6 +27,10 @@ type modifiedHookBase struct { | ||||
|  | ||||
| var errTestHook = errors.New("error") | ||||
|  | ||||
| func (h *modifiedHookBase) ID() string { | ||||
| 	return "modified" | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) Init(config any) error { | ||||
| 	if config != nil { | ||||
| 		return errTestHook | ||||
| @@ -46,6 +50,14 @@ func (h *modifiedHookBase) Stop() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) OnConnect(cl *Client, pk packets.Packet) error { | ||||
| 	if h.fail { | ||||
| 		return errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | ||||
| 	return true | ||||
| } | ||||
| @@ -178,12 +190,20 @@ func TestHooksProvides(t *testing.T) { | ||||
| 	require.False(t, h.Provides(OnDisconnect)) | ||||
| } | ||||
|  | ||||
| func TestHooksAddAndLen(t *testing.T) { | ||||
| func TestHooksAddLenGetAll(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	err := h.Add(new(HookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, int64(1), atomic.LoadInt64(&h.qty)) | ||||
| 	require.Equal(t, int64(1), h.Len()) | ||||
|  | ||||
| 	err = h.Add(new(modifiedHookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.Equal(t, int64(2), atomic.LoadInt64(&h.qty)) | ||||
| 	require.Equal(t, int64(2), h.Len()) | ||||
|  | ||||
| 	all := h.GetAll() | ||||
| 	require.Equal(t, "base", all[0].ID()) | ||||
| 	require.Equal(t, "modified", all[1].ID()) | ||||
| } | ||||
|  | ||||
| func TestHooksAddInitFailure(t *testing.T) { | ||||
| @@ -216,7 +236,6 @@ func TestHooksNonReturns(t *testing.T) { | ||||
| 			h.OnStarted() | ||||
| 			h.OnStopped() | ||||
| 			h.OnSysInfoTick(new(system.Info)) | ||||
| 			h.OnConnect(cl, packets.Packet{}) | ||||
| 			h.OnSessionEstablished(cl, packets.Packet{}) | ||||
| 			h.OnDisconnect(cl, nil, false) | ||||
| 			h.OnPacketSent(cl, packets.Packet{}, []byte{}) | ||||
| @@ -224,14 +243,16 @@ func TestHooksNonReturns(t *testing.T) { | ||||
| 			h.OnSubscribed(cl, packets.Packet{}, []byte{1}) | ||||
| 			h.OnUnsubscribed(cl, packets.Packet{}) | ||||
| 			h.OnPublished(cl, packets.Packet{}) | ||||
| 			h.OnPublishDropped(cl, packets.Packet{}) | ||||
| 			h.OnRetainMessage(cl, packets.Packet{}, 0) | ||||
| 			h.OnRetainPublished(cl, packets.Packet{}) | ||||
| 			h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0) | ||||
| 			h.OnQosComplete(cl, packets.Packet{}) | ||||
| 			h.OnQosDropped(cl, packets.Packet{}) | ||||
| 			h.OnPacketIDExhausted(cl, packets.Packet{}) | ||||
| 			h.OnWillSent(cl, packets.Packet{}) | ||||
| 			h.OnClientExpired(cl) | ||||
| 			h.OnRetainedExpired("a/b/c") | ||||
| 			h.OnExpireInflights(cl, time.Now().Unix()-1) | ||||
|  | ||||
| 			// on second iteration, check added hook methods | ||||
| 			err := h.Add(new(modifiedHookBase), nil) | ||||
| @@ -325,7 +346,7 @@ func TestHooksOnPublish(t *testing.T) { | ||||
| 	// coverage: failure | ||||
| 	hook.fail = true | ||||
| 	pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
|  | ||||
| 	// coverage: reject packet | ||||
| @@ -380,6 +401,22 @@ func TestHooksOnAuthPacket(t *testing.T) { | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
| } | ||||
|  | ||||
| func TestHooksOnConnect(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err := h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.OnConnect(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	hook.fail = true | ||||
| 	err = h.OnConnect(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestHooksOnPacketEncode(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
| @@ -552,12 +589,19 @@ func TestHookBaseOnConnectAuthenticate(t *testing.T) { | ||||
| 	v := h.OnConnectAuthenticate(new(Client), packets.Packet{}) | ||||
| 	require.False(t, v) | ||||
| } | ||||
|  | ||||
| func TestHookBaseOnACLCheck(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	v := h.OnACLCheck(new(Client), "topic", true) | ||||
| 	require.False(t, v) | ||||
| } | ||||
|  | ||||
| func TestHookBaseOnConnect(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	err := h.OnConnect(new(Client), packets.Packet{}) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestHookBaseOnPublish(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10}) | ||||
|   | ||||
							
								
								
									
										24
									
								
								inflight.go
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								inflight.go
									
									
									
									
									
								
							| @@ -58,6 +58,18 @@ func (i *Inflight) Len() int { | ||||
| 	return len(i.internal) | ||||
| } | ||||
|  | ||||
| // Clone returns a new instance of Inflight with the same message data. | ||||
| // This is used when transferring inflights from a taken-over session. | ||||
| func (i *Inflight) Clone() *Inflight { | ||||
| 	c := NewInflights() | ||||
| 	i.RLock() | ||||
| 	defer i.RUnlock() | ||||
| 	for k, v := range i.internal { | ||||
| 		c.internal[k] = v | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| // GetAll returns all the inflight messages. | ||||
| func (i *Inflight) GetAll(immediate bool) []packets.Packet { | ||||
| 	i.RLock() | ||||
| @@ -104,14 +116,14 @@ func (i *Inflight) Delete(id uint16) bool { | ||||
| } | ||||
|  | ||||
| // TakeRecieveQuota reduces the receive quota by 1. | ||||
| func (i *Inflight) TakeReceiveQuota() { | ||||
| func (i *Inflight) DecreaseReceiveQuota() { | ||||
| 	if atomic.LoadInt32(&i.receiveQuota) > 0 { | ||||
| 		atomic.AddInt32(&i.receiveQuota, -1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TakeRecieveQuota increases the receive quota by 1. | ||||
| func (i *Inflight) ReturnReceiveQuota() { | ||||
| func (i *Inflight) IncreaseReceiveQuota() { | ||||
| 	if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) { | ||||
| 		atomic.AddInt32(&i.receiveQuota, 1) | ||||
| 	} | ||||
| @@ -123,15 +135,15 @@ func (i *Inflight) ResetReceiveQuota(n int32) { | ||||
| 	atomic.StoreInt32(&i.maximumReceiveQuota, n) | ||||
| } | ||||
|  | ||||
| // TakeSendQuota reduces the send quota by 1. | ||||
| func (i *Inflight) TakeSendQuota() { | ||||
| // DecreaseSendQuota reduces the send quota by 1. | ||||
| func (i *Inflight) DecreaseSendQuota() { | ||||
| 	if atomic.LoadInt32(&i.sendQuota) > 0 { | ||||
| 		atomic.AddInt32(&i.sendQuota, -1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReturnSendQuota increases the send quota by 1. | ||||
| func (i *Inflight) ReturnSendQuota() { | ||||
| // IncreaseSendQuota increases the send quota by 1. | ||||
| func (i *Inflight) IncreaseSendQuota() { | ||||
| 	if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) { | ||||
| 		atomic.AddInt32(&i.sendQuota, 1) | ||||
| 	} | ||||
|   | ||||
| @@ -61,6 +61,16 @@ func TestInflightLen(t *testing.T) { | ||||
| 	require.Equal(t, 1, cl.State.Inflight.Len()) | ||||
| } | ||||
|  | ||||
| func TestInflightClone(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | ||||
| 	require.Equal(t, 1, cl.State.Inflight.Len()) | ||||
|  | ||||
| 	cloned := cl.State.Inflight.Clone() | ||||
| 	require.NotNil(t, cloned) | ||||
| 	require.NotSame(t, cloned, cl.State.Inflight) | ||||
| } | ||||
|  | ||||
| func TestInflightDelete(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| @@ -95,12 +105,12 @@ func TestReceiveQuota(t *testing.T) { | ||||
| 	require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Return 1 | ||||
| 	i.ReturnReceiveQuota() | ||||
| 	i.IncreaseReceiveQuota() | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Try to go over max limit | ||||
| 	i.ReturnReceiveQuota() | ||||
| 	i.IncreaseReceiveQuota() | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| @@ -110,12 +120,12 @@ func TestReceiveQuota(t *testing.T) { | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Take 1 | ||||
| 	i.TakeReceiveQuota() | ||||
| 	i.DecreaseReceiveQuota() | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Try to go below zero | ||||
| 	i.TakeReceiveQuota() | ||||
| 	i.DecreaseReceiveQuota() | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | ||||
| } | ||||
| @@ -137,12 +147,12 @@ func TestSendQuota(t *testing.T) { | ||||
| 	require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Return 1 | ||||
| 	i.ReturnSendQuota() | ||||
| 	i.IncreaseSendQuota() | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Try to go over max limit | ||||
| 	i.ReturnSendQuota() | ||||
| 	i.IncreaseSendQuota() | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| @@ -152,12 +162,12 @@ func TestSendQuota(t *testing.T) { | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Take 1 | ||||
| 	i.TakeSendQuota() | ||||
| 	i.DecreaseSendQuota() | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Try to go below zero | ||||
| 	i.TakeSendQuota() | ||||
| 	i.DecreaseSendQuota() | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | ||||
| } | ||||
|   | ||||
							
								
								
									
										104
									
								
								listeners/http_healthcheck.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								listeners/http_healthcheck.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,104 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2023 mochi-co | ||||
| // SPDX-FileContributor: Derek Duncan | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| // HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint. | ||||
| type HTTPHealthCheck struct { | ||||
| 	sync.RWMutex | ||||
| 	id      string          // the internal id of the listener | ||||
| 	address string          // the network address to bind to | ||||
| 	config  *Config         // configuration values for the listener | ||||
| 	listen  *http.Server    // the http server | ||||
| 	log     *zerolog.Logger // server logger | ||||
| 	end     uint32          // ensure the close methods are only called once | ||||
| } | ||||
|  | ||||
| // NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address. | ||||
| func NewHTTPHealthCheck(id, address string, config *Config) *HTTPHealthCheck { | ||||
| 	if config == nil { | ||||
| 		config = new(Config) | ||||
| 	} | ||||
| 	return &HTTPHealthCheck{ | ||||
| 		id:      id, | ||||
| 		address: address, | ||||
| 		config:  config, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ID returns the id of the listener. | ||||
| func (l *HTTPHealthCheck) ID() string { | ||||
| 	return l.id | ||||
| } | ||||
|  | ||||
| // Address returns the address of the listener. | ||||
| func (l *HTTPHealthCheck) Address() string { | ||||
| 	return l.address | ||||
| } | ||||
|  | ||||
| // Protocol returns the address of the listener. | ||||
| func (l *HTTPHealthCheck) Protocol() string { | ||||
| 	if l.listen != nil && l.listen.TLSConfig != nil { | ||||
| 		return "https" | ||||
| 	} | ||||
|  | ||||
| 	return "http" | ||||
| } | ||||
|  | ||||
| // Init initializes the listener. | ||||
| func (l *HTTPHealthCheck) Init(log *zerolog.Logger) error { | ||||
| 	l.log = log | ||||
|  | ||||
| 	mux := http.NewServeMux() | ||||
| 	mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		if r.Method != http.MethodGet { | ||||
| 			w.WriteHeader(http.StatusMethodNotAllowed) | ||||
| 		} | ||||
| 	}) | ||||
| 	l.listen = &http.Server{ | ||||
| 		ReadTimeout:  5 * time.Second, | ||||
| 		WriteTimeout: 5 * time.Second, | ||||
| 		Addr:         l.address, | ||||
| 		Handler:      mux, | ||||
| 	} | ||||
|  | ||||
| 	if l.config.TLSConfig != nil { | ||||
| 		l.listen.TLSConfig = l.config.TLSConfig | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Serve starts listening for new connections and serving responses. | ||||
| func (l *HTTPHealthCheck) Serve(establish EstablishFn) { | ||||
| 	if l.listen.TLSConfig != nil { | ||||
| 		l.listen.ListenAndServeTLS("", "") | ||||
| 	} else { | ||||
| 		l.listen.ListenAndServe() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Close closes the listener and any client connections. | ||||
| func (l *HTTPHealthCheck) Close(closeClients CloseFn) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
|  | ||||
| 	if atomic.CompareAndSwapUint32(&l.end, 0, 1) { | ||||
| 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 		defer cancel() | ||||
| 		l.listen.Shutdown(ctx) | ||||
| 	} | ||||
|  | ||||
| 	closeClients(l.id) | ||||
| } | ||||
							
								
								
									
										143
									
								
								listeners/http_healthcheck_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								listeners/http_healthcheck_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2023 mochi-co | ||||
| // SPDX-FileContributor: Derek Duncan | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewHTTPHealthCheck(t *testing.T) { | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, nil) | ||||
| 	require.Equal(t, "healthcheck", l.id) | ||||
| 	require.Equal(t, testAddr, l.address) | ||||
| } | ||||
|  | ||||
| func TestHTTPHealthCheckID(t *testing.T) { | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, nil) | ||||
| 	require.Equal(t, "healthcheck", l.ID()) | ||||
| } | ||||
|  | ||||
| func TestHTTPHealthCheckAddress(t *testing.T) { | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, nil) | ||||
| 	require.Equal(t, testAddr, l.Address()) | ||||
| } | ||||
|  | ||||
| func TestHTTPHealthCheckProtocol(t *testing.T) { | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, nil) | ||||
| 	require.Equal(t, "http", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestHTTPHealthCheckTLSProtocol(t *testing.T) { | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}) | ||||
|  | ||||
| 	l.Init(nil) | ||||
| 	require.Equal(t, "https", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestHTTPHealthCheckInit(t *testing.T) { | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, nil) | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.NotNil(t, l.listen) | ||||
| 	require.Equal(t, testAddr, l.listen.Addr) | ||||
| } | ||||
|  | ||||
| func TestHTTPHealthCheckServeAndClose(t *testing.T) { | ||||
| 	// setup http stats listener | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, nil) | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	// call healthcheck | ||||
| 	resp, err := http.Get("http://localhost" + testAddr + "/healthcheck") | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, resp) | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
| 	_, err = io.ReadAll(resp.Body) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	// ensure listening is closed | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
|  | ||||
| 	require.Equal(t, true, closed) | ||||
|  | ||||
| 	_, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck") | ||||
| 	require.Error(t, err) | ||||
| 	<-o | ||||
| } | ||||
|  | ||||
| func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) { | ||||
| 	// setup http stats listener | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, nil) | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	// make disallowed method type http request | ||||
| 	resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, resp) | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
| 	_, err = io.ReadAll(resp.Body) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	// ensure listening is closed | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
|  | ||||
| 	require.Equal(t, true, closed) | ||||
|  | ||||
| 	_, err = http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody) | ||||
| 	require.Error(t, err) | ||||
| 	<-o | ||||
| } | ||||
|  | ||||
| func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) { | ||||
| 	l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}) | ||||
|  | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	l.Close(MockCloser) | ||||
| } | ||||
| @@ -107,28 +107,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) { | ||||
|  | ||||
| // jsonHandler is an HTTP handler which outputs the $SYS stats as JSON. | ||||
| func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) { | ||||
| 	info := &system.Info{ | ||||
| 		Version:             l.sysInfo.Version, | ||||
| 		Started:             atomic.LoadInt64(&l.sysInfo.Started), | ||||
| 		Time:                atomic.LoadInt64(&l.sysInfo.Time), | ||||
| 		Uptime:              atomic.LoadInt64(&l.sysInfo.Uptime), | ||||
| 		BytesReceived:       atomic.LoadInt64(&l.sysInfo.BytesReceived), | ||||
| 		BytesSent:           atomic.LoadInt64(&l.sysInfo.BytesSent), | ||||
| 		ClientsConnected:    atomic.LoadInt64(&l.sysInfo.ClientsConnected), | ||||
| 		ClientsMaximum:      atomic.LoadInt64(&l.sysInfo.ClientsMaximum), | ||||
| 		ClientsTotal:        atomic.LoadInt64(&l.sysInfo.ClientsTotal), | ||||
| 		ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected), | ||||
| 		MessagesReceived:    atomic.LoadInt64(&l.sysInfo.MessagesReceived), | ||||
| 		MessagesSent:        atomic.LoadInt64(&l.sysInfo.MessagesSent), | ||||
| 		InflightDropped:     atomic.LoadInt64(&l.sysInfo.InflightDropped), | ||||
| 		Subscriptions:       atomic.LoadInt64(&l.sysInfo.Subscriptions), | ||||
| 		PacketsReceived:     atomic.LoadInt64(&l.sysInfo.PacketsReceived), | ||||
| 		PacketsSent:         atomic.LoadInt64(&l.sysInfo.PacketsSent), | ||||
| 		Retained:            atomic.LoadInt64(&l.sysInfo.Retained), | ||||
| 		Inflight:            atomic.LoadInt64(&l.sysInfo.Inflight), | ||||
| 		MemoryAlloc:         atomic.LoadInt64(&l.sysInfo.MemoryAlloc), | ||||
| 		Threads:             atomic.LoadInt64(&l.sysInfo.Threads), | ||||
| 	} | ||||
| 	info := *l.sysInfo.Clone() | ||||
|  | ||||
| 	out, err := json.MarshalIndent(info, "", "\t") | ||||
| 	if err != nil { | ||||
|   | ||||
							
								
								
									
										92
									
								
								listeners/net.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								listeners/net.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,92 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2023 mochi-co | ||||
| // SPDX-FileContributor: Jeroen Rinzema | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
|  | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| // Net is a listener for establishing client connections on basic TCP protocol. | ||||
| type Net struct { // [MQTT-4.2.0-1] | ||||
| 	mu       sync.Mutex | ||||
| 	listener net.Listener    // a net.Listener which will listen for new clients | ||||
| 	id       string          // the internal id of the listener | ||||
| 	log      *zerolog.Logger // server logger | ||||
| 	end      uint32          // ensure the close methods are only called once | ||||
| } | ||||
|  | ||||
| // NewNet initialises and returns a listener serving incoming connections on the given net.Listener | ||||
| func NewNet(id string, listener net.Listener) *Net { | ||||
| 	return &Net{ | ||||
| 		id:       id, | ||||
| 		listener: listener, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ID returns the id of the listener. | ||||
| func (l *Net) ID() string { | ||||
| 	return l.id | ||||
| } | ||||
|  | ||||
| // Address returns the address of the listener. | ||||
| func (l *Net) Address() string { | ||||
| 	return l.listener.Addr().String() | ||||
| } | ||||
|  | ||||
| // Protocol returns the network of the listener. | ||||
| func (l *Net) Protocol() string { | ||||
| 	return l.listener.Addr().Network() | ||||
| } | ||||
|  | ||||
| // Init initializes the listener. | ||||
| func (l *Net) Init(log *zerolog.Logger) error { | ||||
| 	l.log = log | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Serve starts waiting for new TCP connections, and calls the establish | ||||
| // connection callback for any received. | ||||
| func (l *Net) Serve(establish EstablishFn) { | ||||
| 	for { | ||||
| 		if atomic.LoadUint32(&l.end) == 1 { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		conn, err := l.listener.Accept() | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if atomic.LoadUint32(&l.end) == 0 { | ||||
| 			go func() { | ||||
| 				err = establish(l.id, conn) | ||||
| 				if err != nil { | ||||
| 					l.log.Warn().Err(err).Send() | ||||
| 				} | ||||
| 			}() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Close closes the listener and any client connections. | ||||
| func (l *Net) Close(closeClients CloseFn) { | ||||
| 	l.mu.Lock() | ||||
| 	defer l.mu.Unlock() | ||||
|  | ||||
| 	if atomic.CompareAndSwapUint32(&l.end, 0, 1) { | ||||
| 		closeClients(l.id) | ||||
| 	} | ||||
|  | ||||
| 	if l.listener != nil { | ||||
| 		err := l.listener.Close() | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										105
									
								
								listeners/net_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								listeners/net_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewNet(t *testing.T) { | ||||
| 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l := NewNet("t1", n) | ||||
| 	require.Equal(t, "t1", l.id) | ||||
| } | ||||
|  | ||||
| func TestNetID(t *testing.T) { | ||||
| 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l := NewNet("t1", n) | ||||
| 	require.Equal(t, "t1", l.ID()) | ||||
| } | ||||
|  | ||||
| func TestNetAddress(t *testing.T) { | ||||
| 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l := NewNet("t1", n) | ||||
| 	require.Equal(t, n.Addr().String(), l.Address()) | ||||
| } | ||||
|  | ||||
| func TestNetProtocol(t *testing.T) { | ||||
| 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l := NewNet("t1", n) | ||||
| 	require.Equal(t, "tcp", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestNetInit(t *testing.T) { | ||||
| 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l := NewNet("t1", n) | ||||
| 	err = l.Init(&logger) | ||||
| 	l.Close(MockCloser) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestNetServeAndClose(t *testing.T) { | ||||
| 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l := NewNet("t1", n) | ||||
| 	err = l.Init(&logger) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
|  | ||||
| 	require.True(t, closed) | ||||
| 	<-o | ||||
|  | ||||
| 	l.Close(MockCloser)      // coverage: close closed | ||||
| 	l.Serve(MockEstablisher) // coverage: serve closed | ||||
| } | ||||
|  | ||||
| func TestNetEstablishThenEnd(t *testing.T) { | ||||
| 	n, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l := NewNet("t1", n) | ||||
| 	err = l.Init(&logger) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	established := make(chan bool) | ||||
| 	go func() { | ||||
| 		l.Serve(func(id string, c net.Conn) error { | ||||
| 			established <- true | ||||
| 			return errors.New("ending") // return an error to exit immediately | ||||
| 		}) | ||||
| 		o <- true | ||||
| 	}() | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	net.Dial("tcp", n.Addr().String()) | ||||
| 	require.Equal(t, true, <-established) | ||||
| 	l.Close(MockCloser) | ||||
| 	<-o | ||||
| } | ||||
							
								
								
									
										98
									
								
								listeners/unixsock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								listeners/unixsock.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,98 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: jason@zgwit.com | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
|  | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| // UnixSock is a listener for establishing client connections on basic UnixSock protocol. | ||||
| type UnixSock struct { | ||||
| 	sync.RWMutex | ||||
| 	id      string          // the internal id of the listener. | ||||
| 	address string          // the network address to bind to. | ||||
| 	listen  net.Listener    // a net.Listener which will listen for new clients. | ||||
| 	log     *zerolog.Logger // server logger | ||||
| 	end     uint32          // ensure the close methods are only called once. | ||||
| } | ||||
|  | ||||
| // NewUnixSock initialises and returns a new UnixSock listener, listening on an address. | ||||
| func NewUnixSock(id, address string) *UnixSock { | ||||
| 	return &UnixSock{ | ||||
| 		id:      id, | ||||
| 		address: address, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ID returns the id of the listener. | ||||
| func (l *UnixSock) ID() string { | ||||
| 	return l.id | ||||
| } | ||||
|  | ||||
| // Address returns the address of the listener. | ||||
| func (l *UnixSock) Address() string { | ||||
| 	return l.address | ||||
| } | ||||
|  | ||||
| // Protocol returns the address of the listener. | ||||
| func (l *UnixSock) Protocol() string { | ||||
| 	return "unix" | ||||
| } | ||||
|  | ||||
| // Init initializes the listener. | ||||
| func (l *UnixSock) Init(log *zerolog.Logger) error { | ||||
| 	l.log = log | ||||
|  | ||||
| 	var err error | ||||
| 	_ = os.Remove(l.address) | ||||
| 	l.listen, err = net.Listen("unix", l.address) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Serve starts waiting for new UnixSock connections, and calls the establish | ||||
| // connection callback for any received. | ||||
| func (l *UnixSock) Serve(establish EstablishFn) { | ||||
| 	for { | ||||
| 		if atomic.LoadUint32(&l.end) == 1 { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		conn, err := l.listen.Accept() | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if atomic.LoadUint32(&l.end) == 0 { | ||||
| 			go func() { | ||||
| 				err = establish(l.id, conn) | ||||
| 				if err != nil { | ||||
| 					l.log.Warn().Err(err).Send() | ||||
| 				} | ||||
| 			}() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Close closes the listener and any client connections. | ||||
| func (l *UnixSock) Close(closeClients CloseFn) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
|  | ||||
| 	if atomic.CompareAndSwapUint32(&l.end, 0, 1) { | ||||
| 		closeClients(l.id) | ||||
| 	} | ||||
|  | ||||
| 	if l.listen != nil { | ||||
| 		err := l.listen.Close() | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										96
									
								
								listeners/unixsock_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										96
									
								
								listeners/unixsock_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,96 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: jason@zgwit.com | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| const testUnixAddr = "mochi.sock" | ||||
|  | ||||
| func TestNewUnixSock(t *testing.T) { | ||||
| 	l := NewUnixSock("t1", testUnixAddr) | ||||
| 	require.Equal(t, "t1", l.id) | ||||
| 	require.Equal(t, testUnixAddr, l.address) | ||||
| } | ||||
|  | ||||
| func TestUnixSockID(t *testing.T) { | ||||
| 	l := NewUnixSock("t1", testUnixAddr) | ||||
| 	require.Equal(t, "t1", l.ID()) | ||||
| } | ||||
|  | ||||
| func TestUnixSockAddress(t *testing.T) { | ||||
| 	l := NewUnixSock("t1", testUnixAddr) | ||||
| 	require.Equal(t, testUnixAddr, l.Address()) | ||||
| } | ||||
|  | ||||
| func TestUnixSockProtocol(t *testing.T) { | ||||
| 	l := NewUnixSock("t1", testUnixAddr) | ||||
| 	require.Equal(t, "unix", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestUnixSockInit(t *testing.T) { | ||||
| 	l := NewUnixSock("t1", testUnixAddr) | ||||
| 	err := l.Init(&logger) | ||||
| 	l.Close(MockCloser) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l2 := NewUnixSock("t2", testUnixAddr) | ||||
| 	err = l2.Init(&logger) | ||||
| 	l2.Close(MockCloser) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestUnixSockServeAndClose(t *testing.T) { | ||||
| 	l := NewUnixSock("t1", testUnixAddr) | ||||
| 	err := l.Init(&logger) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
|  | ||||
| 	require.True(t, closed) | ||||
| 	<-o | ||||
|  | ||||
| 	l.Close(MockCloser)      // coverage: close closed | ||||
| 	l.Serve(MockEstablisher) // coverage: serve closed | ||||
| } | ||||
|  | ||||
| func TestUnixSockEstablishThenEnd(t *testing.T) { | ||||
| 	l := NewUnixSock("t1", testUnixAddr) | ||||
| 	err := l.Init(&logger) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	established := make(chan bool) | ||||
| 	go func() { | ||||
| 		l.Serve(func(id string, c net.Conn) error { | ||||
| 			established <- true | ||||
| 			return errors.New("ending") // return an error to exit immediately | ||||
| 		}) | ||||
| 		o <- true | ||||
| 	}() | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	net.Dial("unix", l.listen.Addr().String()) | ||||
| 	require.Equal(t, true, <-established) | ||||
| 	l.Close(MockCloser) | ||||
| 	<-o | ||||
| } | ||||
| @@ -154,7 +154,7 @@ func (ws *wsConn) Read(p []byte) (int, error) { | ||||
| 		br, err = r.Read(p[n:]) | ||||
| 		n += br | ||||
| 		if err != nil { | ||||
| 			if err == io.EOF { | ||||
| 			if errors.Is(err, io.EOF) { | ||||
| 				err = nil | ||||
| 			} | ||||
| 			return n, err | ||||
|   | ||||
| @@ -376,7 +376,7 @@ func TestEncodeUint16(t *testing.T) { | ||||
| 	result = encodeUint16(32767) | ||||
| 	require.Equal(t, []byte{0x7f, 0xff}, result) | ||||
|  | ||||
| 	result = encodeUint16(65535) | ||||
| 	result = encodeUint16(math.MaxUint16) | ||||
| 	require.Equal(t, []byte{0xff, 0xff}, result) | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -21,7 +21,7 @@ func (c Code) Error() string { | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	// QosCodes indicicates the reason codes for each Qos byte. | ||||
| 	// QosCodes indicates the reason codes for each Qos byte. | ||||
| 	QosCodes = map[byte]Code{ | ||||
| 		0: CodeGrantedQos0, | ||||
| 		1: CodeGrantedQos1, | ||||
| @@ -113,15 +113,35 @@ var ( | ||||
| 	ErrPacketTooLarge                         = Code{Code: 0x95, Reason: "packet too large"} | ||||
| 	ErrMessageRateTooHigh                     = Code{Code: 0x96, Reason: "message rate too high"} | ||||
| 	ErrQuotaExceeded                          = Code{Code: 0x97, Reason: "quota exceeded"} | ||||
| 	ErrPendingClientWritesExceeded            = Code{Code: 0x97, Reason: "too many pending writes"} | ||||
| 	ErrAdministrativeAction                   = Code{Code: 0x98, Reason: "administrative action"} | ||||
| 	ErrPayloadFormatInvalid                   = Code{Code: 0x99, Reason: "payload format invalid"} | ||||
| 	ErrRetainNotSupported                     = Code{Code: 0x9A, Reason: "retain not supported"} | ||||
| 	ErrQosNotSupported                        = Code{Code: 0x9B, Reason: "qos not supported"} | ||||
| 	ErrUseAnotherServer                       = Code{Code: 0x9C, Reason: "use another server"} | ||||
| 	ErrServerMoved                            = Code{Code: 0x9D, Reason: "server moved"} | ||||
| 	ErrSharedSubscriptionsNotSupported        = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"} | ||||
| 	ErrSharedSubscriptionsNotSupported        = Code{Code: 0x9E, Reason: "shared subscriptions not supported"} | ||||
| 	ErrConnectionRateExceeded                 = Code{Code: 0x9F, Reason: "connection rate exceeded"} | ||||
| 	ErrMaxConnectTime                         = Code{Code: 0xA0, Reason: "maximum connect time"} | ||||
| 	ErrSubscriptionIdentifiersNotSupported    = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} | ||||
| 	ErrWildcardSubscriptionsNotSupported      = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} | ||||
|  | ||||
| 	// MQTTv3 specific bytes. | ||||
| 	Err3UnsupportedProtocolVersion = Code{Code: 0x01} | ||||
| 	Err3ClientIdentifierNotValid   = Code{Code: 0x02} | ||||
| 	Err3ServerUnavailable          = Code{Code: 0x03} | ||||
| 	ErrMalformedUsernameOrPassword = Code{Code: 0x04} | ||||
| 	Err3NotAuthorized              = Code{Code: 0x05} | ||||
|  | ||||
| 	// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes. | ||||
| 	// This is required because MQTTv3 has different return byte specification. | ||||
| 	// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257 | ||||
| 	V5CodesToV3 = map[Code]Code{ | ||||
| 		ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion, | ||||
| 		ErrClientIdentifierNotValid:   Err3ClientIdentifierNotValid, | ||||
| 		ErrServerUnavailable:          Err3ServerUnavailable, | ||||
| 		ErrMalformedUsername:          ErrMalformedUsernameOrPassword, | ||||
| 		ErrMalformedPassword:          ErrMalformedUsernameOrPassword, | ||||
| 		ErrBadUsernameOrPassword:      Err3NotAuthorized, | ||||
| 	} | ||||
| ) | ||||
|   | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| @@ -15,22 +16,23 @@ import ( | ||||
|  | ||||
| // All of the valid packet types and their packet identifier. | ||||
| const ( | ||||
| 	Reserved    byte = iota // 0 - we use this in packet tests to indicate special-test or all packets. | ||||
| 	Connect                 // 1 | ||||
| 	Connack                 // 2 | ||||
| 	Publish                 // 3 | ||||
| 	Puback                  // 4 | ||||
| 	Pubrec                  // 5 | ||||
| 	Pubrel                  // 6 | ||||
| 	Pubcomp                 // 7 | ||||
| 	Subscribe               // 8 | ||||
| 	Suback                  // 9 | ||||
| 	Unsubscribe             // 10 | ||||
| 	Unsuback                // 11 | ||||
| 	Pingreq                 // 12 | ||||
| 	Pingresp                // 13 | ||||
| 	Disconnect              // 14 | ||||
| 	Auth                    // 15 | ||||
| 	Reserved       byte = iota // 0 - we use this in packet tests to indicate special-test or all packets. | ||||
| 	Connect                    // 1 | ||||
| 	Connack                    // 2 | ||||
| 	Publish                    // 3 | ||||
| 	Puback                     // 4 | ||||
| 	Pubrec                     // 5 | ||||
| 	Pubrel                     // 6 | ||||
| 	Pubcomp                    // 7 | ||||
| 	Subscribe                  // 8 | ||||
| 	Suback                     // 9 | ||||
| 	Unsubscribe                // 10 | ||||
| 	Unsuback                   // 11 | ||||
| 	Pingreq                    // 12 | ||||
| 	Pingresp                   // 13 | ||||
| 	Disconnect                 // 14 | ||||
| 	Auth                       // 15 | ||||
| 	WillProperties byte = 99   // Special byte for validating Will Properties. | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| @@ -208,7 +210,10 @@ func (pk *Packet) Copy(allowTransfer bool) Packet { | ||||
| 		Created:        pk.Created, | ||||
| 		Expiry:         pk.Expiry, | ||||
| 		Origin:         pk.Origin, | ||||
| 		PacketID:       pk.PacketID, // ... ? Packet ID must not be transferred (in this manner) | ||||
| 	} | ||||
|  | ||||
| 	if allowTransfer { | ||||
| 		p.PacketID = pk.PacketID | ||||
| 	} | ||||
|  | ||||
| 	if len(pk.Connect.ProtocolName) > 0 { | ||||
| @@ -309,7 +314,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	if pk.ProtocolVersion == 5 { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		(&pk.Properties).Encode(pk, pb, 0) | ||||
| 		(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0) | ||||
| 		nb.Write(pb.Bytes()) | ||||
| 	} | ||||
|  | ||||
| @@ -318,7 +323,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { | ||||
| 	if pk.Connect.WillFlag { | ||||
| 		if pk.ProtocolVersion == 5 { | ||||
| 			pb := bytes.NewBuffer([]byte{}) | ||||
| 			(&pk.Connect).WillProperties.Encode(pk, pb, 0) | ||||
| 			(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0) | ||||
| 			nb.Write(pb.Bytes()) | ||||
| 		} | ||||
|  | ||||
| @@ -379,7 +384,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error { | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("%s: %w", err, ErrMalformedProperties) | ||||
| 		} | ||||
| 		offset += n + 1 | ||||
| 		offset += n | ||||
| 	} | ||||
|  | ||||
| 	pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) //[MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4] | ||||
| @@ -389,11 +394,11 @@ func (pk *Packet) ConnectDecode(buf []byte) error { | ||||
|  | ||||
| 	if pk.Connect.WillFlag { // [MQTT-3.1.2-7] | ||||
| 		if pk.ProtocolVersion == 5 { | ||||
| 			n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) | ||||
| 			n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:])) | ||||
| 			if err != nil { | ||||
| 				return ErrMalformedWillProperties | ||||
| 			} | ||||
| 			offset += n + 1 | ||||
| 			offset += n | ||||
| 		} | ||||
|  | ||||
| 		pk.Connect.WillTopic, offset, err = decodeString(buf, offset) | ||||
| @@ -408,6 +413,10 @@ func (pk *Packet) ConnectDecode(buf []byte) error { | ||||
| 	} | ||||
|  | ||||
| 	if pk.Connect.UsernameFlag { // [MQTT-3.1.3-12] | ||||
| 		if offset >= len(buf) { // we are at the end of the packet | ||||
| 			return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17] | ||||
| 		} | ||||
|  | ||||
| 		pk.Connect.Username, offset, err = decodeBytes(buf, offset) | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedUsername | ||||
| @@ -439,18 +448,14 @@ func (pk *Packet) ConnectValidate() Code { | ||||
| 		return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3] | ||||
| 	} | ||||
|  | ||||
| 	if len(pk.Connect.Password) > 65535 { | ||||
| 	if len(pk.Connect.Password) > math.MaxUint16 { | ||||
| 		return ErrProtocolViolationPasswordTooLong | ||||
| 	} | ||||
|  | ||||
| 	if len(pk.Connect.Username) > 65535 { | ||||
| 	if len(pk.Connect.Username) > math.MaxUint16 { | ||||
| 		return ErrProtocolViolationUsernameTooLong | ||||
| 	} | ||||
|  | ||||
| 	if pk.Connect.UsernameFlag && len(pk.Connect.Username) == 0 { | ||||
| 		return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17] | ||||
| 	} | ||||
|  | ||||
| 	if !pk.Connect.UsernameFlag && len(pk.Connect.Username) > 0 { | ||||
| 		return ErrProtocolViolationUsernameNoFlag // [MQTT-3.1.2-16] | ||||
| 	} | ||||
| @@ -463,7 +468,7 @@ func (pk *Packet) ConnectValidate() Code { | ||||
| 		return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18] | ||||
| 	} | ||||
|  | ||||
| 	if len(pk.Connect.ClientIdentifier) > 65535 { | ||||
| 	if len(pk.Connect.ClientIdentifier) > math.MaxUint16 { | ||||
| 		return ErrClientIdentifierNotValid | ||||
| 	} | ||||
|  | ||||
| @@ -492,7 +497,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	if pk.ProtocolVersion == 5 { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode | ||||
| 		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode | ||||
| 		nb.Write(pb.Bytes()) | ||||
| 	} | ||||
|  | ||||
| @@ -535,7 +540,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { | ||||
| 		nb.WriteByte(pk.ReasonCode) | ||||
|  | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		pk.Properties.Encode(pk, pb, nb.Len()) | ||||
| 		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) | ||||
| 		nb.Write(pb.Bytes()) | ||||
| 	} | ||||
|  | ||||
| @@ -604,7 +609,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	if pk.ProtocolVersion == 5 { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload)) | ||||
| 		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload)) | ||||
| 		nb.Write(pb.Bytes()) | ||||
| 	} | ||||
|  | ||||
| @@ -640,7 +645,7 @@ func (pk *Packet) PublishDecode(buf []byte) error { | ||||
| 			return fmt.Errorf("%s: %w", err, ErrMalformedProperties) | ||||
| 		} | ||||
|  | ||||
| 		offset += n + 1 | ||||
| 		offset += n | ||||
| 	} | ||||
|  | ||||
| 	pk.Payload = buf[offset:] | ||||
| @@ -688,7 +693,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	if pk.ProtocolVersion == 5 { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		pk.Properties.Encode(pk, pb, nb.Len()) | ||||
| 		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) | ||||
| 		if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 { | ||||
| 			nb.WriteByte(pk.ReasonCode) | ||||
| 		} | ||||
| @@ -829,7 +834,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	if pk.ProtocolVersion == 5 { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes)) | ||||
| 		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes)) | ||||
| 		nb.Write(pb.Bytes()) | ||||
| 	} | ||||
|  | ||||
| @@ -857,7 +862,7 @@ func (pk *Packet) SubackDecode(buf []byte) error { | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("%s: %w", err, ErrMalformedProperties) | ||||
| 		} | ||||
| 		offset += n + 1 | ||||
| 		offset += n | ||||
| 	} | ||||
|  | ||||
| 	pk.ReasonCodes = buf[offset:] | ||||
| @@ -886,7 +891,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	if pk.ProtocolVersion == 5 { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		pk.Properties.Encode(pk, pb, nb.Len()+xb.Len()) | ||||
| 		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len()) | ||||
| 		nb.Write(pb.Bytes()) | ||||
| 	} | ||||
|  | ||||
| @@ -914,7 +919,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error { | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("%s: %w", err, ErrMalformedProperties) | ||||
| 		} | ||||
| 		offset += n + 1 | ||||
| 		offset += n | ||||
| 	} | ||||
|  | ||||
| 	var filter string | ||||
| @@ -981,7 +986,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	if pk.ProtocolVersion == 5 { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		pk.Properties.Encode(pk, pb, nb.Len()) | ||||
| 		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) | ||||
| 		nb.Write(pb.Bytes()) | ||||
| 	} | ||||
|  | ||||
| @@ -1010,7 +1015,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error { | ||||
| 			return fmt.Errorf("%s: %w", err, ErrMalformedProperties) | ||||
| 		} | ||||
|  | ||||
| 		offset += n + 1 | ||||
| 		offset += n | ||||
|  | ||||
| 		pk.ReasonCodes = buf[offset:] | ||||
| 	} | ||||
| @@ -1034,7 +1039,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	if pk.ProtocolVersion == 5 { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		pk.Properties.Encode(pk, pb, nb.Len()+xb.Len()) | ||||
| 		pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len()) | ||||
| 		nb.Write(pb.Bytes()) | ||||
| 	} | ||||
|  | ||||
| @@ -1062,7 +1067,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error { | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("%s: %w", err, ErrMalformedProperties) | ||||
| 		} | ||||
| 		offset += n + 1 | ||||
| 		offset += n | ||||
| 	} | ||||
|  | ||||
| 	var filter string | ||||
| @@ -1097,7 +1102,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error { | ||||
| 	nb.WriteByte(pk.ReasonCode) | ||||
|  | ||||
| 	pb := bytes.NewBuffer([]byte{}) | ||||
| 	pk.Properties.Encode(pk, pb, nb.Len()) | ||||
| 	pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) | ||||
| 	nb.Write(pb.Bytes()) | ||||
|  | ||||
| 	pk.FixedHeader.Remaining = nb.Len() | ||||
|   | ||||
| @@ -464,6 +464,9 @@ func TestCopy(t *testing.T) { | ||||
| 		require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.EqualValues(t, pkc.Properties, tt.Packet.Properties) | ||||
|  | ||||
| 		pkcc := tt.Packet.Copy(false) | ||||
| 		require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc) | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -42,11 +42,11 @@ const ( | ||||
|  | ||||
| // validPacketProperties indicates which properties are valid for which packet types. | ||||
| var validPacketProperties = map[byte]map[byte]byte{ | ||||
| 	PropPayloadFormat:          {Publish: 1}, | ||||
| 	PropMessageExpiryInterval:  {Publish: 1}, | ||||
| 	PropContentType:            {Publish: 1}, | ||||
| 	PropResponseTopic:          {Publish: 1}, | ||||
| 	PropCorrelationData:        {Publish: 1}, | ||||
| 	PropPayloadFormat:          {Publish: 1, WillProperties: 1}, | ||||
| 	PropMessageExpiryInterval:  {Publish: 1, WillProperties: 1}, | ||||
| 	PropContentType:            {Publish: 1, WillProperties: 1}, | ||||
| 	PropResponseTopic:          {Publish: 1, WillProperties: 1}, | ||||
| 	PropCorrelationData:        {Publish: 1, WillProperties: 1}, | ||||
| 	PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1}, | ||||
| 	PropSessionExpiryInterval:  {Connect: 1, Connack: 1, Disconnect: 1}, | ||||
| 	PropAssignedClientID:       {Connack: 1}, | ||||
| @@ -54,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{ | ||||
| 	PropAuthenticationMethod:   {Connect: 1, Connack: 1, Auth: 1}, | ||||
| 	PropAuthenticationData:     {Connect: 1, Connack: 1, Auth: 1}, | ||||
| 	PropRequestProblemInfo:     {Connect: 1}, | ||||
| 	PropWillDelayInterval:      {Connect: 1}, | ||||
| 	PropWillDelayInterval:      {WillProperties: 1}, | ||||
| 	PropRequestResponseInfo:    {Connect: 1}, | ||||
| 	PropResponseInfo:           {Connack: 1}, | ||||
| 	PropServerReference:        {Connack: 1, Disconnect: 1}, | ||||
| @@ -64,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{ | ||||
| 	PropTopicAlias:             {Publish: 1}, | ||||
| 	PropMaximumQos:             {Connack: 1}, | ||||
| 	PropRetainAvailable:        {Connack: 1}, | ||||
| 	PropUser:                   {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1}, | ||||
| 	PropUser:                   {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1}, | ||||
| 	PropMaximumPacketSize:      {Connect: 1, Connack: 1}, | ||||
| 	PropWildcardSubAvailable:   {Connack: 1}, | ||||
| 	PropSubIDAvailable:         {Connack: 1}, | ||||
| @@ -194,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool { | ||||
| } | ||||
|  | ||||
| // Encode encodes properties into a bytes buffer. | ||||
| func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { | ||||
| func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) { | ||||
| 	if p == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	pkt := pk.FixedHeader.Type | ||||
|  | ||||
| 	if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag { | ||||
| 		buf.WriteByte(PropPayloadFormat) | ||||
| 		buf.WriteByte(p.PayloadFormat) | ||||
| @@ -217,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { | ||||
| 		buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19] | ||||
| 	} | ||||
|  | ||||
| 	if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && //  [MQTT-3.3.2-14] | ||||
| 	if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && //  [MQTT-3.3.2-14] | ||||
| 		p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28] | ||||
| 		buf.WriteByte(PropResponseTopic) | ||||
| 		buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13] | ||||
| 	} | ||||
|  | ||||
| 	if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28] | ||||
| 	if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28] | ||||
| 		buf.WriteByte(PropCorrelationData) | ||||
| 		buf.Write(encodeBytes(p.CorrelationData)) | ||||
| 	} | ||||
| @@ -277,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { | ||||
| 		buf.WriteByte(p.RequestResponseInfo) | ||||
| 	} | ||||
|  | ||||
| 	if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28] | ||||
| 	if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28] | ||||
| 		buf.WriteByte(PropResponseInfo) | ||||
| 		buf.Write(encodeString(p.ResponseInfo)) | ||||
| 	} | ||||
| @@ -289,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { | ||||
|  | ||||
| 	// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2] | ||||
| 	// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2] | ||||
| 	if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" { | ||||
| 	if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" { | ||||
| 		b := encodeString(p.ReasonString) | ||||
| 		if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize { | ||||
| 		if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize { | ||||
| 			buf.WriteByte(PropReasonString) | ||||
| 			buf.Write(b) | ||||
| 		} | ||||
| @@ -322,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { | ||||
| 		buf.WriteByte(p.RetainAvailable) | ||||
| 	} | ||||
|  | ||||
| 	if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) { | ||||
| 	if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		for _, v := range p.User { | ||||
| 			pb.WriteByte(PropUser) | ||||
| @@ -331,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { | ||||
| 		} | ||||
| 		// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3] | ||||
| 		// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3] | ||||
| 		if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize { | ||||
| 		if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize { | ||||
| 			buf.Write(pb.Bytes()) | ||||
| 		} | ||||
| 	} | ||||
| @@ -361,18 +359,19 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { | ||||
| } | ||||
|  | ||||
| // Decode decodes property bytes into a properties struct. | ||||
| func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { | ||||
| func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) { | ||||
| 	if p == nil { | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	n, _, err = DecodeLength(b) | ||||
| 	var bu int | ||||
| 	n, bu, err = DecodeLength(b) | ||||
| 	if err != nil { | ||||
| 		return n, err | ||||
| 		return n + bu, err | ||||
| 	} | ||||
|  | ||||
| 	if n == 0 { | ||||
| 		return n, nil | ||||
| 		return n + bu, nil | ||||
| 	} | ||||
|  | ||||
| 	bt := b.Bytes() | ||||
| @@ -380,11 +379,11 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { | ||||
| 	for offset := 0; offset < n; { | ||||
| 		k, offset, err = decodeByte(bt, offset) | ||||
| 		if err != nil { | ||||
| 			return n, err | ||||
| 			return n + bu, err | ||||
| 		} | ||||
|  | ||||
| 		if _, ok := validPacketProperties[k][pk]; !ok { | ||||
| 			return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty) | ||||
| 		if _, ok := validPacketProperties[k][pkt]; !ok { | ||||
| 			return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty) | ||||
| 		} | ||||
|  | ||||
| 		switch k { | ||||
| @@ -406,7 +405,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { | ||||
|  | ||||
| 			n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:])) | ||||
| 			if err != nil { | ||||
| 				return n, err | ||||
| 				return n + bu, err | ||||
| 			} | ||||
| 			p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n) | ||||
| 			offset += bu | ||||
| @@ -452,7 +451,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { | ||||
| 			var k, v string | ||||
| 			k, offset, err = decodeString(bt, offset) | ||||
| 			if err != nil { | ||||
| 				return n, err | ||||
| 				return n + bu, err | ||||
| 			} | ||||
| 			v, offset, err = decodeString(bt, offset) | ||||
| 			p.User = append(p.User, UserProperty{Key: k, Val: v}) | ||||
| @@ -470,9 +469,9 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { | ||||
| 		} | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return n, err | ||||
| 			return n + bu, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return n, nil | ||||
| 	return n + bu, nil | ||||
| } | ||||
|   | ||||
| @@ -202,14 +202,14 @@ func init() { | ||||
| func TestEncodeProperties(t *testing.T) { | ||||
| 	props := propertiesStruct | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0) | ||||
| 	props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0) | ||||
| 	require.Equal(t, propertiesBytes, b.Bytes()) | ||||
| } | ||||
|  | ||||
| func TestEncodePropertiesDisallowProblemInfo(t *testing.T) { | ||||
| 	props := propertiesStruct | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0) | ||||
| 	props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0) | ||||
| 	require.NotEqual(t, propertiesBytes, b.Bytes()) | ||||
| 	require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6})) | ||||
| 	require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5})) | ||||
| @@ -219,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) { | ||||
| func TestEncodePropertiesDisallowResponseInfo(t *testing.T) { | ||||
| 	props := propertiesStruct | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0) | ||||
| 	props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0) | ||||
| 	require.NotEqual(t, propertiesBytes, b.Bytes()) | ||||
| 	require.NotContains(t, b.Bytes(), []byte{8, 0, 5}) | ||||
| 	require.NotContains(t, b.Bytes(), []byte{9, 0, 4}) | ||||
| @@ -232,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) { | ||||
|  | ||||
| 	pr := tmp{} | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0) | ||||
| 	pr.p.Encode(Reserved, Mods{}, b, 0) | ||||
| 	require.Equal(t, []byte{}, b.Bytes()) | ||||
| } | ||||
|  | ||||
| @@ -240,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) { | ||||
| 	// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero. | ||||
| 	props := new(Properties) | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0) | ||||
| 	props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0) | ||||
| 	require.Equal(t, []byte{0x00}, b.Bytes()) | ||||
| } | ||||
|  | ||||
| @@ -250,7 +250,7 @@ func TestDecodeProperties(t *testing.T) { | ||||
| 	props := new(Properties) | ||||
| 	n, err := props.Decode(Reserved, b) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, 172, n) | ||||
| 	require.Equal(t, 172+2, n) | ||||
| 	require.EqualValues(t, propertiesStruct, *props) | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -71,7 +71,7 @@ const ( | ||||
| 	TConnectInvalidWillFlagNoPayload | ||||
| 	TConnectInvalidWillFlagQosOutOfRange | ||||
| 	TConnectInvalidWillSurplusRetain | ||||
| 	TConnectNotCleanNoClientID | ||||
| 	TConnectZeroByteUsername | ||||
| 	TConnectSpecInvalidUTF8D800 | ||||
| 	TConnectSpecInvalidUTF8DFFF | ||||
| 	TConnectSpecInvalidUTF80000 | ||||
| @@ -82,6 +82,7 @@ const ( | ||||
| 	TConnackAcceptedAdjustedExpiryInterval | ||||
| 	TConnackMinMqtt5 | ||||
| 	TConnackMinCleanMqtt5 | ||||
| 	TConnackServerKeepalive | ||||
| 	TConnackInvalidMinMqtt5 | ||||
| 	TConnackBadProtocolVersion | ||||
| 	TConnackProtocolViolationNoSession | ||||
| @@ -89,6 +90,7 @@ const ( | ||||
| 	TConnackServerUnavailable | ||||
| 	TConnackBadUsernamePassword | ||||
| 	TConnackBadUsernamePasswordNoSession | ||||
| 	TConnackMqtt5BadUsernamePasswordNoSession | ||||
| 	TConnackNotAuthorised | ||||
| 	TConnackMalSessionPresent | ||||
| 	TConnackMalReturnCode | ||||
| @@ -249,26 +251,26 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 			Desc:    "mqtt v3.1.1", | ||||
| 			Primary: true, | ||||
| 			RawBytes: []byte{ | ||||
| 				Connect << 4, 16, // Fixed header | ||||
| 				Connect << 4, 15, // Fixed header | ||||
| 				0, 4, // Protocol Name - MSB+LSB | ||||
| 				'M', 'Q', 'T', 'T', // Protocol Name | ||||
| 				4,     // Protocol Version | ||||
| 				0,     // Packet Flags | ||||
| 				0, 60, // Keepalive | ||||
| 				0, 4, // Client ID - MSB+LSB | ||||
| 				'z', 'e', 'n', '3', // Client ID "zen" | ||||
| 				0, 3, // Client ID - MSB+LSB | ||||
| 				'z', 'e', 'n', // Client ID "zen" | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Connect, | ||||
| 					Remaining: 16, | ||||
| 					Remaining: 15, | ||||
| 				}, | ||||
| 				ProtocolVersion: 4, | ||||
| 				Connect: ConnectParams{ | ||||
| 					ProtocolName:     []byte("MQTT"), | ||||
| 					Clean:            false, | ||||
| 					Keepalive:        60, | ||||
| 					ClientIdentifier: "zen3", | ||||
| 					ClientIdentifier: "zen", | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -425,9 +427,9 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 				Connect << 4, 28, // Fixed header | ||||
| 				0, 4, // Protocol Name - MSB+LSB | ||||
| 				'M', 'Q', 'T', 'T', // Protocol Name | ||||
| 				4,     // Protocol Version | ||||
| 				194,   // Packet Flags | ||||
| 				0, 20, // Keepalive | ||||
| 				4,               // Protocol Version | ||||
| 				0 | 1<<6 | 1<<7, // Packet Flags | ||||
| 				0, 20,           // Keepalive | ||||
| 				0, 3, // Client ID - MSB+LSB | ||||
| 				'z', 'e', 'n', // Client ID "zen" | ||||
| 				0, 5, // Username MSB+LSB | ||||
| @@ -443,7 +445,7 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 				ProtocolVersion: 4, | ||||
| 				Connect: ConnectParams{ | ||||
| 					ProtocolName:     []byte("MQTT"), | ||||
| 					Clean:            true, | ||||
| 					Clean:            false, | ||||
| 					Keepalive:        20, | ||||
| 					ClientIdentifier: "zen", | ||||
| 					UsernameFlag:     true, | ||||
| @@ -497,6 +499,43 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Case:  TConnectZeroByteUsername, | ||||
| 			Desc:  "username flag but 0 byte username", | ||||
| 			Group: "decode", | ||||
| 			RawBytes: []byte{ | ||||
| 				Connect << 4, 23, // Fixed header | ||||
| 				0, 4, // Protocol Name - MSB+LSB | ||||
| 				'M', 'Q', 'T', 'T', // Protocol Name | ||||
| 				5,     // Protocol Version | ||||
| 				130,   // Packet Flags | ||||
| 				0, 30, // Keepalive | ||||
| 				5,                // length | ||||
| 				17, 0, 0, 0, 120, // Session Expiry Interval (17) | ||||
| 				0, 3, // Client ID - MSB+LSB | ||||
| 				'z', 'e', 'n', // Client ID "zen" | ||||
| 				0, 0, // Username MSB+LSB | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Connect, | ||||
| 					Remaining: 23, | ||||
| 				}, | ||||
| 				ProtocolVersion: 5, | ||||
| 				Connect: ConnectParams{ | ||||
| 					ProtocolName:     []byte("MQTT"), | ||||
| 					Clean:            true, | ||||
| 					Keepalive:        30, | ||||
| 					ClientIdentifier: "zen", | ||||
| 					Username:         []byte{}, | ||||
| 					UsernameFlag:     true, | ||||
| 				}, | ||||
| 				Properties: Properties{ | ||||
| 					SessionExpiryInterval:     uint32(120), | ||||
| 					SessionExpiryIntervalFlag: true, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
|  | ||||
| 		// Fail States | ||||
| 		{ | ||||
| @@ -623,6 +662,24 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 				'm', 'o', 'c', | ||||
| 			}, | ||||
| 		}, | ||||
|  | ||||
| 		{ | ||||
| 			Case:      TConnectInvalidFlagNoUsername, | ||||
| 			Desc:      "username flag with no username bytes", | ||||
| 			Group:     "decode", | ||||
| 			FailFirst: ErrProtocolViolationFlagNoUsername, | ||||
| 			RawBytes: []byte{ | ||||
| 				Connect << 4, 17, // Fixed header | ||||
| 				0, 4, // Protocol Name - MSB+LSB | ||||
| 				'M', 'Q', 'T', 'T', // Protocol Name | ||||
| 				5,     // Protocol Version | ||||
| 				130,   // Flags | ||||
| 				0, 20, // Keepalive | ||||
| 				0, | ||||
| 				0, 3, // Client ID - MSB+LSB | ||||
| 				'z', 'e', 'n', // Client ID "zen" | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Case:      TConnectMalPassword, | ||||
| 			Desc:      "malformed password", | ||||
| @@ -783,20 +840,6 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Case:   TConnectInvalidFlagNoUsername, | ||||
| 			Desc:   "has username flag but no username", | ||||
| 			Group:  "validate", | ||||
| 			Expect: ErrProtocolViolationFlagNoUsername, | ||||
| 			Packet: &Packet{ | ||||
| 				FixedHeader:     FixedHeader{Type: Connect}, | ||||
| 				ProtocolVersion: 4, | ||||
| 				Connect: ConnectParams{ | ||||
| 					ProtocolName: []byte("MQTT"), | ||||
| 					UsernameFlag: true, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Case:   TConnectInvalidUsernameNoFlag, | ||||
| 			Desc:   "has username but no flag", | ||||
| @@ -1043,25 +1086,22 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 			Desc:    "accepted, no session, adjusted expiry interval mqtt5", | ||||
| 			Primary: true, | ||||
| 			RawBytes: []byte{ | ||||
| 				Connack << 4, 11, // fixed header | ||||
| 				Connack << 4, 8, // fixed header | ||||
| 				0, // Session present | ||||
| 				CodeSuccess.Code, | ||||
| 				8,                // length | ||||
| 				5,                // length | ||||
| 				17, 0, 0, 0, 120, // Session Expiry Interval (17) | ||||
| 				19, 0, 10, // Server Keep Alive (19) | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| 				ProtocolVersion: 5, | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Connack, | ||||
| 					Remaining: 11, | ||||
| 					Remaining: 8, | ||||
| 				}, | ||||
| 				ReasonCode: CodeSuccess.Code, | ||||
| 				Properties: Properties{ | ||||
| 					SessionExpiryInterval:     uint32(120), | ||||
| 					SessionExpiryIntervalFlag: true, | ||||
| 					ServerKeepAlive:           uint16(10), | ||||
| 					ServerKeepAliveFlag:       true, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -1148,28 +1188,25 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 			Desc:    "accepted min properties mqtt5", | ||||
| 			Primary: true, | ||||
| 			RawBytes: []byte{ | ||||
| 				Connack << 4, 16, // fixed header | ||||
| 				Connack << 4, 13, // fixed header | ||||
| 				1, // existing session | ||||
| 				CodeSuccess.Code, | ||||
| 				13,                                // Properties length | ||||
| 				10,                                // Properties length | ||||
| 				18, 0, 5, 'm', 'o', 'c', 'h', 'i', // Assigned Client ID (18) | ||||
| 				19, 0, 20, // Server Keep Alive (19) | ||||
| 				36, 1, // Maximum Qos (36) | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| 				ProtocolVersion: 5, | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Connack, | ||||
| 					Remaining: 16, | ||||
| 					Remaining: 13, | ||||
| 				}, | ||||
| 				SessionPresent: true, | ||||
| 				ReasonCode:     CodeSuccess.Code, | ||||
| 				Properties: Properties{ | ||||
| 					ServerKeepAlive:     uint16(20), | ||||
| 					ServerKeepAliveFlag: true, | ||||
| 					AssignedClientID:    "mochi", | ||||
| 					MaximumQos:          byte(1), | ||||
| 					MaximumQosFlag:      true, | ||||
| 					AssignedClientID: "mochi", | ||||
| 					MaximumQos:       byte(1), | ||||
| 					MaximumQosFlag:   true, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -1178,11 +1215,10 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 			Desc:    "accepted min properties mqtt5b", | ||||
| 			Primary: true, | ||||
| 			RawBytes: []byte{ | ||||
| 				Connack << 4, 6, // fixed header | ||||
| 				Connack << 4, 3, // fixed header | ||||
| 				0, // existing session | ||||
| 				CodeSuccess.Code, | ||||
| 				3,         // Properties length | ||||
| 				19, 0, 10, // server keepalive | ||||
| 				0, // Properties length | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| 				ProtocolVersion: 5, | ||||
| @@ -1192,6 +1228,27 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 				}, | ||||
| 				SessionPresent: false, | ||||
| 				ReasonCode:     CodeSuccess.Code, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Case:    TConnackServerKeepalive, | ||||
| 			Desc:    "server set keepalive", | ||||
| 			Primary: true, | ||||
| 			RawBytes: []byte{ | ||||
| 				Connack << 4, 6, // fixed header | ||||
| 				1, // existing session | ||||
| 				CodeSuccess.Code, | ||||
| 				3,         // Properties length | ||||
| 				19, 0, 10, // server keepalive | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| 				ProtocolVersion: 5, | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Connack, | ||||
| 					Remaining: 6, | ||||
| 				}, | ||||
| 				SessionPresent: true, | ||||
| 				ReasonCode:     CodeSuccess.Code, | ||||
| 				Properties: Properties{ | ||||
| 					ServerKeepAlive:     uint16(10), | ||||
| 					ServerKeepAliveFlag: true, | ||||
| @@ -1203,26 +1260,23 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 			Desc:    "failure min properties mqtt5", | ||||
| 			Primary: true, | ||||
| 			RawBytes: append([]byte{ | ||||
| 				Connack << 4, 26, // fixed header | ||||
| 				Connack << 4, 23, // fixed header | ||||
| 				0, // No existing session | ||||
| 				ErrUnspecifiedError.Code, | ||||
| 				// Properties | ||||
| 				23,        // length | ||||
| 				19, 0, 20, // Server Keep Alive (19) | ||||
| 				20,        // length | ||||
| 				31, 0, 17, // Reason String (31) | ||||
| 			}, []byte(ErrUnspecifiedError.Reason)...), | ||||
| 			Packet: &Packet{ | ||||
| 				ProtocolVersion: 5, | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Connack, | ||||
| 					Remaining: 25, | ||||
| 					Remaining: 23, | ||||
| 				}, | ||||
| 				SessionPresent: false, | ||||
| 				ReasonCode:     ErrUnspecifiedError.Code, | ||||
| 				Properties: Properties{ | ||||
| 					ServerKeepAlive:     uint16(20), | ||||
| 					ServerKeepAliveFlag: true, | ||||
| 					ReasonString:        ErrUnspecifiedError.Reason, | ||||
| 					ReasonString: ErrUnspecifiedError.Reason, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| @@ -1316,10 +1370,28 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 			Desc: "bad username or password no session", | ||||
| 			RawBytes: []byte{ | ||||
| 				Connack << 4, 2, // fixed header | ||||
| 				0, // No session present | ||||
| 				ErrBadUsernameOrPassword.Code, | ||||
| 				0,                      // No session present | ||||
| 				Err3NotAuthorized.Code, // use v3 remapping | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Connack, | ||||
| 					Remaining: 2, | ||||
| 				}, | ||||
| 				ReasonCode: Err3NotAuthorized.Code, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Case: TConnackMqtt5BadUsernamePasswordNoSession, | ||||
| 			Desc: "mqtt5 bad username or password no session", | ||||
| 			RawBytes: []byte{ | ||||
| 				Connack << 4, 3, // fixed header | ||||
| 				0, // No session present | ||||
| 				ErrBadUsernameOrPassword.Code, | ||||
| 				0, | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| 				ProtocolVersion: 5, | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Connack, | ||||
| 					Remaining: 2, | ||||
| @@ -1327,6 +1399,7 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 				ReasonCode: ErrBadUsernameOrPassword.Code, | ||||
| 			}, | ||||
| 		}, | ||||
|  | ||||
| 		{ | ||||
| 			Case: TConnackNotAuthorised, | ||||
| 			Desc: "not authorised", | ||||
| @@ -1804,13 +1877,10 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 			Case: TPublishRetainMqtt5, | ||||
| 			Desc: "retain mqtt5", | ||||
| 			RawBytes: []byte{ | ||||
| 				Publish<<4 | 1<<0, 35, // Fixed header | ||||
| 				Publish<<4 | 1<<0, 19, // Fixed header | ||||
| 				0, 5, // Topic Name - LSB+MSB | ||||
| 				'a', '/', 'b', '/', 'c', // Topic Name | ||||
| 				16, // properties length | ||||
| 				38, // User Properties (38) | ||||
| 				0, 5, 'h', 'e', 'l', 'l', 'o', | ||||
| 				0, 6, 228, 184, 150, 231, 149, 140, | ||||
| 				0,                                                     // properties length | ||||
| 				'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload | ||||
| 			}, | ||||
| 			Packet: &Packet{ | ||||
| @@ -1818,18 +1888,11 @@ var TPacketData = map[byte]TPacketCases{ | ||||
| 				FixedHeader: FixedHeader{ | ||||
| 					Type:      Publish, | ||||
| 					Retain:    true, | ||||
| 					Remaining: 35, | ||||
| 					Remaining: 19, | ||||
| 				}, | ||||
| 				TopicName: "a/b/c", | ||||
| 				Properties: Properties{ | ||||
| 					User: []UserProperty{ | ||||
| 						{ | ||||
| 							Key: "hello", | ||||
| 							Val: "世界", | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				Payload: []byte("hello mochi"), | ||||
| 				TopicName:  "a/b/c", | ||||
| 				Properties: Properties{}, | ||||
| 				Payload:    []byte("hello mochi"), | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
|   | ||||
							
								
								
									
										341
									
								
								server.go
									
									
									
									
									
								
							
							
						
						
									
										341
									
								
								server.go
									
									
									
									
									
								
							| @@ -26,10 +26,8 @@ import ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	Version                        = "2.0.5" // the current server version. | ||||
| 	defaultSysTopicInterval int64  = 1       // the interval between $SYS topic publishes | ||||
| 	defaultFanPoolSize      uint64 = 32      // the number of concurrent workers in the pool | ||||
| 	defaultFanPoolQueueSize uint64 = 1024    // the capacity of each worker queue | ||||
| 	Version                       = "2.2.14" // the current server version. | ||||
| 	defaultSysTopicInterval int64 = 1        // the interval between $SYS topic publishes | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| @@ -45,8 +43,8 @@ var ( | ||||
| 		WildcardSubAvailable:         1,              // wildcard subscriptions are available | ||||
| 		SubIDAvailable:               1,              // subscription identifiers are available | ||||
| 		SharedSubAvailable:           1,              // shared subscriptions are available | ||||
| 		ServerKeepAlive:              10,             // default keepalive for clients | ||||
| 		MinimumProtocolVersion:       3,              // minimum supported mqtt version (3.0.0) | ||||
| 		MaximumClientWritesPending:   1024 * 8,       // maximum number of pending message writes for a client | ||||
| 	} | ||||
|  | ||||
| 	ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists. | ||||
| @@ -56,18 +54,19 @@ var ( | ||||
| // Capabilities indicates the capabilities and features provided by the server. | ||||
| type Capabilities struct { | ||||
| 	MaximumMessageExpiryInterval int64 | ||||
| 	MaximumClientWritesPending   int32 | ||||
| 	MaximumSessionExpiryInterval uint32 | ||||
| 	MaximumPacketSize            uint32 | ||||
| 	maximumPacketID              uint32 // unexported, used for testing only | ||||
| 	ReceiveMaximum               uint16 | ||||
| 	TopicAliasMaximum            uint16 | ||||
| 	ServerKeepAlive              uint16 | ||||
| 	SharedSubAvailable           byte | ||||
| 	MinimumProtocolVersion       byte | ||||
| 	Compatibilities              Compatibilities | ||||
| 	MaximumQos                   byte | ||||
| 	RetainAvailable              byte | ||||
| 	WildcardSubAvailable         byte | ||||
| 	SubIDAvailable               byte | ||||
| 	SharedSubAvailable           byte | ||||
| 	MinimumProtocolVersion       byte | ||||
| } | ||||
|  | ||||
| // Compatibilities provides flags for using compatibility modes. | ||||
| @@ -80,9 +79,17 @@ type Compatibilities struct { | ||||
|  | ||||
| // Options contains configurable options for the server. | ||||
| type Options struct { | ||||
| 	// Capabilities defines the server features and behaviour. | ||||
| 	// Capabilities defines the server features and behaviour. If you only wish to modify | ||||
| 	// several of these values, set them explicitly - e.g. | ||||
| 	// 	server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 | ||||
| 	Capabilities *Capabilities | ||||
|  | ||||
| 	// ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer. | ||||
| 	ClientNetWriteBufferSize int | ||||
|  | ||||
| 	// ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer. | ||||
| 	ClientNetReadBufferSize int | ||||
|  | ||||
| 	// Logger specifies a custom configured implementation of zerolog to override | ||||
| 	// the servers default logger configuration. If you wish to change the log level, | ||||
| 	// of the default logger, you can do so by setting | ||||
| @@ -91,16 +98,6 @@ type Options struct { | ||||
| 	// 	server.Log = &l | ||||
| 	Logger *zerolog.Logger | ||||
|  | ||||
| 	// FanPoolSize is the number of individual workers and queues to initialize. | ||||
| 	// Bigger is not necessarily better, and you should rely on defaults unless | ||||
| 	// you have know what you are doing. | ||||
| 	FanPoolSize uint64 | ||||
|  | ||||
| 	// FanPoolQueueSize is the size of the queue per worker. Increase this value | ||||
| 	// accordingly if you anticipate having intermittent but massive numbers of | ||||
| 	// messages. Cluster support is roadmapped. | ||||
| 	FanPoolQueueSize uint64 | ||||
|  | ||||
| 	// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds. | ||||
| 	SysTopicResendInterval int64 | ||||
| } | ||||
| @@ -113,7 +110,6 @@ type Server struct { | ||||
| 	Clients   *Clients             // clients known to the broker | ||||
| 	Topics    *TopicsIndex         // an index of topic filter subscriptions and retained messages | ||||
| 	Info      *system.Info         // values about the server commonly known as $SYS topics | ||||
| 	fanpool   *FanPool             // a fixed size worker pool for processing inbound and outbound messages | ||||
| 	loop      *loop                // loop contains tickers for the system event loop | ||||
| 	done      chan bool            // indicate that the server is ending | ||||
| 	Log       *zerolog.Logger      // minimal no-alloc logger | ||||
| @@ -132,10 +128,10 @@ type loop struct { | ||||
|  | ||||
| // ops contains server values which can be propagated to other structs. | ||||
| type ops struct { | ||||
| 	capabilities *Capabilities   // a pointer to the server capabilities, for referencing in clients | ||||
| 	info         *system.Info    // pointers to server system info | ||||
| 	hooks        *Hooks          // pointer to the server hooks | ||||
| 	log          *zerolog.Logger // a structured logger for the client | ||||
| 	options *Options        // a pointer to the server options and capabilities, for referencing in clients | ||||
| 	info    *system.Info    // pointers to server system info | ||||
| 	hooks   *Hooks          // pointer to the server hooks | ||||
| 	log     *zerolog.Logger // a structured logger for the client | ||||
| } | ||||
|  | ||||
| // New returns a new instance of mochi mqtt broker. Optional parameters | ||||
| @@ -165,8 +161,7 @@ func New(opts *Options) *Server { | ||||
| 			Version: Version, | ||||
| 			Started: time.Now().Unix(), | ||||
| 		}, | ||||
| 		fanpool: NewFanPool(opts.FanPoolSize, opts.FanPoolQueueSize), | ||||
| 		Log:     opts.Logger, | ||||
| 		Log: opts.Logger, | ||||
| 		hooks: &Hooks{ | ||||
| 			Log: opts.Logger, | ||||
| 		}, | ||||
| @@ -181,16 +176,18 @@ func (o *Options) ensureDefaults() { | ||||
| 		o.Capabilities = DefaultServerCapabilities | ||||
| 	} | ||||
|  | ||||
| 	o.Capabilities.maximumPacketID = math.MaxUint16 // spec maximum is 65535 | ||||
|  | ||||
| 	if o.SysTopicResendInterval == 0 { | ||||
| 		o.SysTopicResendInterval = defaultSysTopicInterval | ||||
| 	} | ||||
|  | ||||
| 	if o.FanPoolSize == 0 { | ||||
| 		o.FanPoolSize = defaultFanPoolSize | ||||
| 	if o.ClientNetWriteBufferSize == 0 { | ||||
| 		o.ClientNetWriteBufferSize = 1024 * 2 | ||||
| 	} | ||||
|  | ||||
| 	if o.FanPoolQueueSize < 1 { | ||||
| 		o.FanPoolQueueSize = defaultFanPoolQueueSize | ||||
| 	if o.ClientNetReadBufferSize == 0 { | ||||
| 		o.ClientNetReadBufferSize = 1024 * 2 | ||||
| 	} | ||||
|  | ||||
| 	if o.Logger == nil { | ||||
| @@ -205,10 +202,10 @@ func (o *Options) ensureDefaults() { | ||||
| // topic validation checks. | ||||
| func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client { | ||||
| 	cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit | ||||
| 		capabilities: s.Options.Capabilities, | ||||
| 		info:         s.Info, | ||||
| 		hooks:        s.hooks, | ||||
| 		log:          s.Log, | ||||
| 		options: s.Options, | ||||
| 		info:    s.Info, | ||||
| 		hooks:   s.hooks, | ||||
| 		log:     s.Log, | ||||
| 	}) | ||||
|  | ||||
| 	cl.ID = id | ||||
| @@ -216,9 +213,11 @@ func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) | ||||
|  | ||||
| 	if inline { // inline clients bypass acl and some validity checks. | ||||
| 		cl.Net.Inline = true | ||||
| 		// By default we don't want to restrict developer publishes, | ||||
| 		// By default, we don't want to restrict developer publishes, | ||||
| 		// but if you do, reset this after creating inline client. | ||||
| 		cl.State.Inflight.ResetReceiveQuota(math.MaxInt32) | ||||
| 	} else { | ||||
| 		go cl.WriteLoop() // can only write to real clients | ||||
| 	} | ||||
|  | ||||
| 	return cl | ||||
| @@ -323,16 +322,20 @@ func (s *Server) attachClient(cl *Client, listener string) error { | ||||
| 	cl.ParseConnect(listener, pk) | ||||
| 	code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2] | ||||
| 	if code != packets.CodeSuccess { | ||||
| 		if err := s.sendConnack(cl, code, false); err != nil { | ||||
| 		if err := s.SendConnack(cl, code, false, nil); err != nil { | ||||
| 			return fmt.Errorf("invalid connection send ack: %w", err) | ||||
| 		} | ||||
| 		return code // [MQTT-3.2.2-7] [MQTT-3.1.4-6] | ||||
| 	} | ||||
|  | ||||
| 	s.hooks.OnConnect(cl, pk) | ||||
| 	err = s.hooks.OnConnect(cl, pk) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	cl.refreshDeadline(cl.State.Keepalive) | ||||
| 	if !s.hooks.OnConnectAuthenticate(cl, pk) { // [MQTT-3.1.4-2] | ||||
| 		err := s.sendConnack(cl, packets.ErrBadUsernameOrPassword, false) | ||||
| 		err := s.SendConnack(cl, packets.ErrBadUsernameOrPassword, false, nil) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("invalid connection send ack: %w", err) | ||||
| 		} | ||||
| @@ -346,7 +349,7 @@ func (s *Server) attachClient(cl *Client, listener string) error { | ||||
| 	sessionPresent := s.inheritClientSession(pk, cl) | ||||
| 	s.Clients.Add(cl) // [MQTT-4.1.0-1] | ||||
|  | ||||
| 	err = s.sendConnack(cl, code, sessionPresent) // [MQTT-3.1.4-5] [MQTT-3.2.0-1] [MQTT-3.2.0-2] &[MQTT-3.14.0-1] | ||||
| 	err = s.SendConnack(cl, code, sessionPresent, nil) // [MQTT-3.1.4-5] [MQTT-3.2.0-1] [MQTT-3.2.0-2] &[MQTT-3.14.0-1] | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("ack connection packet: %w", err) | ||||
| 	} | ||||
| @@ -366,18 +369,17 @@ func (s *Server) attachClient(cl *Client, listener string) error { | ||||
| 	if err != nil { | ||||
| 		s.sendLWT(cl) | ||||
| 		cl.Stop(err) | ||||
| 	} | ||||
|  | ||||
| 	if err == nil { | ||||
| 	} else { | ||||
| 		cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10] | ||||
| 	} | ||||
|  | ||||
| 	s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected") | ||||
| 	expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) | ||||
| 	expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean) | ||||
| 	s.hooks.OnDisconnect(cl, err, expire) | ||||
| 	if expire { | ||||
| 		s.unsubscribeClient(cl) | ||||
|  | ||||
| 	if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 { | ||||
| 		cl.ClearInflights(math.MaxInt64, 0) | ||||
| 		s.UnsubscribeClient(cl) | ||||
| 		s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23] | ||||
| 	} | ||||
|  | ||||
| @@ -451,25 +453,40 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code { | ||||
| // session is abandoned. | ||||
| func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { | ||||
| 	if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok { | ||||
| 		existing.Lock() | ||||
| 		defer existing.Unlock() | ||||
| 		s.DisconnectClient(existing, packets.ErrSessionTakenOver)                                 // [MQTT-3.1.4-3] | ||||
| 		if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] | ||||
| 			s.unsubscribeClient(existing) | ||||
| 		s.DisconnectClient(existing, packets.ErrSessionTakenOver)                                       // [MQTT-3.1.4-3] | ||||
| 		if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4] | ||||
| 			s.UnsubscribeClient(existing) | ||||
| 			existing.ClearInflights(math.MaxInt64, 0) | ||||
| 			return false // [MQTT-3.2.2-3] | ||||
| 			atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred | ||||
| 			return false                                       // [MQTT-3.2.2-3] | ||||
| 		} | ||||
|  | ||||
| 		atomic.StoreUint32(&existing.State.isTakenOver, 1) | ||||
| 		if existing.State.Inflight.Len() > 0 { | ||||
| 			cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5] | ||||
| 			if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 { | ||||
| 				cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client | ||||
| 				cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))            // client receive max | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5] | ||||
| 		for _, sub := range existing.State.Subscriptions.GetAll() { | ||||
| 			existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] | ||||
| 			if !existed { | ||||
| 				atomic.AddInt64(&s.Info.Subscriptions, 1) | ||||
| 			} | ||||
| 			cl.State.Subscriptions.Add(sub.Filter, sub) | ||||
| 			s.publishRetainedToClient(cl, sub, existed) | ||||
| 		} | ||||
|  | ||||
| 		// Clean the state of the existing client to prevent sequential take-overs | ||||
| 		// from increasing memory usage by inflights + subs * client-id. | ||||
| 		s.UnsubscribeClient(existing) | ||||
| 		existing.ClearInflights(math.MaxInt64, 0) | ||||
| 		s.Log.Debug().Str("client", cl.ID). | ||||
| 			Str("old_remote", existing.Net.Remote). | ||||
| 			Str("new_remote", cl.Net.Remote). | ||||
| 			Msg("session taken over") | ||||
|  | ||||
| 		return true // [MQTT-3.2.2-3] | ||||
| 	} | ||||
|  | ||||
| @@ -480,15 +497,27 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool { | ||||
| 	return false // [MQTT-3.2.2-2] | ||||
| } | ||||
|  | ||||
| // sendConnack returns a Connack packet to a client. | ||||
| func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) error { | ||||
| 	properties := packets.Properties{ | ||||
| 		ServerKeepAlive:     s.Options.Capabilities.ServerKeepAlive, // [MQTT-3.1.2-21] | ||||
| 		ServerKeepAliveFlag: true, | ||||
| 		ReceiveMaximum:      s.Options.Capabilities.ReceiveMaximum, // 3.2.2.3.3 Receive Maximum | ||||
| // SendConnack returns a Connack packet to a client. | ||||
| func (s *Server) SendConnack(cl *Client, reason packets.Code, present bool, properties *packets.Properties) error { | ||||
| 	if properties == nil { | ||||
| 		properties = &packets.Properties{ | ||||
| 			ReceiveMaximum: s.Options.Capabilities.ReceiveMaximum, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	properties.ReceiveMaximum = s.Options.Capabilities.ReceiveMaximum // 3.2.2.3.3 Receive Maximum | ||||
| 	if cl.State.ServerKeepalive {                                     // You can set this dynamically using the OnConnect hook. | ||||
| 		properties.ServerKeepAlive = cl.State.Keepalive // [MQTT-3.1.2-21] | ||||
| 		properties.ServerKeepAliveFlag = true | ||||
| 	} | ||||
|  | ||||
| 	if reason.Code >= packets.ErrUnspecifiedError.Code { | ||||
| 		if cl.Properties.ProtocolVersion < 5 { | ||||
| 			if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes | ||||
| 				reason = v3reason | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		properties.ReasonString = reason.Reason | ||||
| 		ack := packets.Packet{ | ||||
| 			FixedHeader: packets.FixedHeader{ | ||||
| @@ -496,9 +525,8 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro | ||||
| 			}, | ||||
| 			SessionPresent: false,       // [MQTT-3.2.2-6] | ||||
| 			ReasonCode:     reason.Code, // [MQTT-3.2.2-8] | ||||
| 			Properties:     properties, | ||||
| 			Properties:     *properties, | ||||
| 		} | ||||
|  | ||||
| 		return cl.WritePacket(ack) | ||||
| 	} | ||||
|  | ||||
| @@ -518,14 +546,15 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro | ||||
| 		cl.Properties.Props.SessionExpiryIntervalFlag = true | ||||
| 	} | ||||
|  | ||||
| 	return cl.WritePacket(packets.Packet{ | ||||
| 	ack := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Type: packets.Connack, | ||||
| 		}, | ||||
| 		SessionPresent: present, | ||||
| 		ReasonCode:     reason.Code, // [MQTT-3.2.2-8] | ||||
| 		Properties:     properties, | ||||
| 	}) | ||||
| 		Properties:     *properties, | ||||
| 	} | ||||
| 	return cl.WritePacket(ack) | ||||
| } | ||||
|  | ||||
| // processPacket processes an inbound packet for a client. Since the method is | ||||
| @@ -588,7 +617,7 @@ func (s *Server) processPacket(cl *Client, pk packets.Packet) error { | ||||
| 			if ok := cl.State.Inflight.Delete(next.PacketID); ok { | ||||
| 				atomic.AddInt64(&s.Info.Inflight, -1) | ||||
| 			} | ||||
| 			cl.State.Inflight.TakeSendQuota() | ||||
| 			cl.State.Inflight.DecreaseSendQuota() | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -649,7 +678,7 @@ func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error { | ||||
|  | ||||
| // processPublish processes a Publish packet. | ||||
| func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | ||||
| 	if !IsValidFilter(pk.TopicName, true) && !cl.Net.Inline { | ||||
| 	if !cl.Net.Inline && !IsValidFilter(pk.TopicName, true) { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| @@ -657,20 +686,22 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | ||||
| 		return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8] | ||||
| 	} | ||||
|  | ||||
| 	if !s.hooks.OnACLCheck(cl, pk.TopicName, true) && !cl.Net.Inline { | ||||
| 	if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	pk.Origin = cl.ID | ||||
| 	pk.Created = time.Now().Unix() | ||||
|  | ||||
| 	if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok && !cl.Net.Inline { | ||||
| 		if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10] | ||||
| 			ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse) | ||||
| 			return cl.WritePacket(ack) | ||||
| 		} | ||||
| 		if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] | ||||
| 			atomic.AddInt64(&s.Info.Inflight, -1) | ||||
| 	if !cl.Net.Inline { | ||||
| 		if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok { | ||||
| 			if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10] | ||||
| 				ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse) | ||||
| 				return cl.WritePacket(ack) | ||||
| 			} | ||||
| 			if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] | ||||
| 				atomic.AddInt64(&s.Info.Inflight, -1) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @@ -693,15 +724,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | ||||
| 	} | ||||
|  | ||||
| 	if pk.FixedHeader.Qos == 0 { | ||||
| 		s.fanpool.Enqueue(cl.ID, func() { | ||||
| 			s.publishToSubscribers(pk) | ||||
| 		}) | ||||
|  | ||||
| 		s.publishToSubscribers(pk) | ||||
| 		s.hooks.OnPublished(cl, pk) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	cl.State.Inflight.TakeReceiveQuota() | ||||
| 	cl.State.Inflight.DecreaseReceiveQuota() | ||||
| 	ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4] | ||||
| 	if pk.FixedHeader.Qos == 2 { | ||||
| 		ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8] | ||||
| @@ -709,6 +737,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | ||||
|  | ||||
| 	if ok := cl.State.Inflight.Set(ack); ok { | ||||
| 		atomic.AddInt64(&s.Info.Inflight, 1) | ||||
| 		s.hooks.OnQosPublish(cl, ack, ack.Created, 0) | ||||
| 	} | ||||
|  | ||||
| 	err := cl.WritePacket(ack) | ||||
| @@ -720,20 +749,18 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error { | ||||
| 		if ok := cl.State.Inflight.Delete(ack.PacketID); ok { | ||||
| 			atomic.AddInt64(&s.Info.Inflight, -1) | ||||
| 		} | ||||
| 		cl.State.Inflight.ReturnReceiveQuota() | ||||
| 		s.hooks.OnQosComplete(cl, pk) | ||||
| 		cl.State.Inflight.IncreaseReceiveQuota() | ||||
| 		s.hooks.OnQosComplete(cl, ack) | ||||
| 	} | ||||
|  | ||||
| 	s.fanpool.Enqueue(cl.ID, func() { | ||||
| 		s.publishToSubscribers(pk) | ||||
| 	}) | ||||
|  | ||||
| 	s.publishToSubscribers(pk) | ||||
| 	s.hooks.OnPublished(cl, pk) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // retainMessage adds a message to a topic, and if a persistent store is provided, | ||||
| // adds the message to the store so it can be reloaded if necessary. | ||||
| // adds the message to the store to be reloaded if necessary. | ||||
| func (s *Server) retainMessage(cl *Client, pk packets.Packet) { | ||||
| 	out := pk.Copy(false) | ||||
| 	r := s.Topics.RetainMessage(out) | ||||
| @@ -771,13 +798,13 @@ func (s *Server) publishToSubscribers(pk packets.Packet) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (out packets.Packet, err error) { | ||||
| func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) { | ||||
| 	if sub.NoLocal && pk.Origin == cl.ID { | ||||
| 		return pk, nil // [MQTT-3.8.3-3] | ||||
| 	} | ||||
|  | ||||
| 	out = pk.Copy(false) | ||||
| 	if !sub.RetainAsPublished { // ![MQTT-3.3.1-13] | ||||
| 	out := pk.Copy(false) | ||||
| 	if cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished { // ![MQTT-3.3.1-13] | ||||
| 		out.FixedHeader.Retain = false // [MQTT-3.3.1-12] | ||||
| 	} | ||||
|  | ||||
| @@ -807,6 +834,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet | ||||
| 	if out.FixedHeader.Qos > 0 { | ||||
| 		i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1] | ||||
| 		if err != nil { | ||||
| 			s.hooks.OnPacketIDExhausted(cl, pk) | ||||
| 			s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted") | ||||
| 			return out, packets.ErrQuotaExceeded | ||||
| 		} | ||||
| @@ -817,22 +845,32 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet | ||||
| 		if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3] | ||||
| 			atomic.AddInt64(&s.Info.Inflight, 1) | ||||
| 			s.hooks.OnQosPublish(cl, out, out.Created, 0) | ||||
| 			cl.State.Inflight.DecreaseSendQuota() | ||||
| 		} | ||||
|  | ||||
| 		if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 { | ||||
| 			out.Expiry = -1 | ||||
| 			cl.State.Inflight.Set(out) | ||||
| 			return pk, nil | ||||
| 			return out, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if cl.Net.conn == nil || atomic.LoadUint32(&cl.State.done) == 1 { | ||||
| 		return pk, packets.CodeDisconnect | ||||
| 	if cl.Net.Conn == nil || cl.Closed() { | ||||
| 		return out, packets.CodeDisconnect | ||||
| 	} | ||||
|  | ||||
| 	cl.State.Inflight.TakeSendQuota() | ||||
| 	select { | ||||
| 	case cl.State.outbound <- &out: | ||||
| 		atomic.AddInt32(&cl.State.outboundQty, 1) | ||||
| 	default: | ||||
| 		atomic.AddInt64(&s.Info.MessagesDropped, 1) | ||||
| 		cl.ops.hooks.OnPublishDropped(cl, pk) | ||||
| 		cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight. | ||||
| 		cl.State.Inflight.IncreaseSendQuota() | ||||
| 		return out, packets.ErrPendingClientWritesExceeded | ||||
| 	} | ||||
|  | ||||
| 	return out, cl.WritePacket(out) | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, existed bool) { | ||||
| @@ -847,12 +885,14 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e | ||||
| 	for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4] | ||||
| 		_, err := s.publishToClient(cl, sub, pkv) | ||||
| 		if err != nil { | ||||
| 			s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message") | ||||
| 			s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message") | ||||
| 			continue | ||||
| 		} | ||||
| 		s.hooks.OnRetainPublished(cl, pkv) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // buildAck builds an standardised ack message for Puback, Pubrec, Pubrel, Pubcomp packets. | ||||
| // buildAck builds a standardised ack message for Puback, Pubrec, Pubrel, Pubcomp packets. | ||||
| func (s *Server) buildAck(packetID uint16, pkt, qos byte, properties packets.Properties, reason packets.Code) packets.Packet { | ||||
| 	properties = packets.Properties{} // PRL | ||||
| 	if reason.Code >= packets.ErrUnspecifiedError.Code { | ||||
| @@ -881,7 +921,7 @@ func (s *Server) processPuback(cl *Client, pk packets.Packet) error { | ||||
| 	} | ||||
|  | ||||
| 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5] | ||||
| 		cl.State.Inflight.ReturnSendQuota() | ||||
| 		cl.State.Inflight.IncreaseSendQuota() | ||||
| 		atomic.AddInt64(&s.Info.Inflight, -1) | ||||
| 		s.hooks.OnQosComplete(cl, pk) | ||||
| 	} | ||||
| @@ -904,7 +944,7 @@ func (s *Server) processPubrec(cl *Client, pk packets.Packet) error { | ||||
| 	} | ||||
|  | ||||
| 	ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6] | ||||
| 	cl.State.Inflight.TakeReceiveQuota()                                                  // -1 RECV QUOTA | ||||
| 	cl.State.Inflight.DecreaseReceiveQuota()                                              // -1 RECV QUOTA | ||||
| 	cl.State.Inflight.Set(ack)                                                            // [MQTT-4.3.3-5] | ||||
| 	return cl.WritePacket(ack) | ||||
| } | ||||
| @@ -931,8 +971,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	cl.State.Inflight.ReturnReceiveQuota()               // +1 RECV QUOTA | ||||
| 	cl.State.Inflight.ReturnSendQuota()                  // +1 SENT QUOTA | ||||
| 	cl.State.Inflight.IncreaseReceiveQuota()             // +1 RECV QUOTA | ||||
| 	cl.State.Inflight.IncreaseSendQuota()                // +1 SENT QUOTA | ||||
| 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12] | ||||
| 		atomic.AddInt64(&s.Info.Inflight, -1) | ||||
| 		s.hooks.OnQosComplete(cl, pk) | ||||
| @@ -944,8 +984,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error { | ||||
| // processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server. | ||||
| func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error { | ||||
| 	// regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas. | ||||
| 	cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA | ||||
| 	cl.State.Inflight.ReturnSendQuota()    // +1 SENT QUOTA | ||||
| 	cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA | ||||
| 	cl.State.Inflight.IncreaseSendQuota()    // +1 SENT QUOTA | ||||
| 	if ok := cl.State.Inflight.Delete(pk.PacketID); ok { | ||||
| 		atomic.AddInt64(&s.Info.Inflight, -1) | ||||
| 		s.hooks.OnQosComplete(cl, pk) | ||||
| @@ -962,24 +1002,24 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { | ||||
| 		code = packets.ErrPacketIdentifierInUse | ||||
| 	} | ||||
|  | ||||
| 	existed := false | ||||
| 	filterExisted := make([]bool, len(pk.Filters)) | ||||
| 	reasonCodes := make([]byte, len(pk.Filters)) | ||||
| 	for i, sub := range pk.Filters { | ||||
| 		if code != packets.CodeSuccess { | ||||
| 			reasonCodes[i] = code.Code // NB 3.9.3 Non-normative 0x91 | ||||
| 			continue | ||||
| 		} else if !IsValidFilter(sub.Filter, false) { | ||||
| 			reasonCodes[i] = packets.ErrTopicFilterInvalid.Code | ||||
| 		} else if sub.NoLocal && IsSharedFilter(sub.Filter) { | ||||
| 			reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4] | ||||
| 		} else if !s.hooks.OnACLCheck(cl, sub.Filter, false) { | ||||
| 			reasonCodes[i] = packets.ErrNotAuthorized.Code | ||||
| 			if s.Options.Capabilities.Compatibilities.ObscureNotAuthorized { | ||||
| 				reasonCodes[i] = packets.ErrUnspecifiedError.Code | ||||
| 			} | ||||
| 		} else if !IsValidFilter(sub.Filter, false) { | ||||
| 			reasonCodes[i] = packets.ErrTopicFilterInvalid.Code | ||||
| 		} else if sub.NoLocal && IsSharedFilter(sub.Filter) { | ||||
| 			reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4] | ||||
| 		} else { | ||||
| 			existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] | ||||
| 			if !existed { | ||||
| 			isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3] | ||||
| 			if isNew { | ||||
| 				atomic.AddInt64(&s.Info.Subscriptions, 1) | ||||
| 			} | ||||
| 			cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10] | ||||
| @@ -988,6 +1028,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { | ||||
| 				sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] | ||||
| 			} | ||||
|  | ||||
| 			filterExisted[i] = !isNew | ||||
| 			reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7] | ||||
| 		} | ||||
|  | ||||
| @@ -996,7 +1037,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	ack := packets.Packet{ //[MQTT-3.8.4-1] [MQTT-3.8.4-5] | ||||
| 	ack := packets.Packet{ // [MQTT-3.8.4-1] [MQTT-3.8.4-5] | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Type: packets.Suback, | ||||
| 		}, | ||||
| @@ -1022,7 +1063,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		s.publishRetainedToClient(cl, sub, existed) | ||||
| 		s.publishRetainedToClient(cl, sub, filterExisted[i]) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| @@ -1072,20 +1113,32 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error { | ||||
| 	return cl.WritePacket(ack) | ||||
| } | ||||
|  | ||||
| // unsubscribeClient unsubscribes a client from all of their subscriptions. | ||||
| func (s *Server) unsubscribeClient(cl *Client) { | ||||
| 	for k := range cl.State.Subscriptions.GetAll() { | ||||
| // UnsubscribeClient unsubscribes a client from all of their subscriptions. | ||||
| func (s *Server) UnsubscribeClient(cl *Client) { | ||||
| 	i := 0 | ||||
| 	filterMap := cl.State.Subscriptions.GetAll() | ||||
| 	filters := make([]packets.Subscription, len(filterMap)) | ||||
| 	for k := range filterMap { | ||||
| 		cl.State.Subscriptions.Delete(k) | ||||
| 	} | ||||
|  | ||||
| 	if atomic.LoadUint32(&cl.State.isTakenOver) == 1 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for k, v := range filterMap { | ||||
| 		if s.Topics.Unsubscribe(k, cl.ID) { | ||||
| 			atomic.AddInt64(&s.Info.Subscriptions, -1) | ||||
| 		} | ||||
| 		filters[i] = v | ||||
| 		i++ | ||||
| 	} | ||||
| 	s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters}) | ||||
| } | ||||
|  | ||||
| // processAuth processes an Auth packet. | ||||
| func (s *Server) processAuth(cl *Client, pk packets.Packet) error { | ||||
| 	_, err := s.hooks.OnAuthPacket(cl, pk) | ||||
| 	fmt.Println("err", err) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -1173,6 +1226,7 @@ func (s *Server) publishSysTopics() { | ||||
| 		SysPrefix + "/broker/packets/sent":         AtomicItoa(&s.Info.PacketsSent), | ||||
| 		SysPrefix + "/broker/messages/received":    AtomicItoa(&s.Info.MessagesReceived), | ||||
| 		SysPrefix + "/broker/messages/sent":        AtomicItoa(&s.Info.MessagesSent), | ||||
| 		SysPrefix + "/broker/messages/dropped":     AtomicItoa(&s.Info.MessagesDropped), | ||||
| 		SysPrefix + "/broker/messages/inflight":    AtomicItoa(&s.Info.Inflight), | ||||
| 		SysPrefix + "/broker/retained":             AtomicItoa(&s.Info.Retained), | ||||
| 		SysPrefix + "/broker/subscriptions":        AtomicItoa(&s.Info.Subscriptions), | ||||
| @@ -1190,12 +1244,10 @@ func (s *Server) publishSysTopics() { | ||||
| 	s.hooks.OnSysInfoTick(s.Info) | ||||
| } | ||||
|  | ||||
| // Close attempts to gracefully shutdown the server, all listeners, clients, and stores. | ||||
| // Close attempts to gracefully shut down the server, all listeners, clients, and stores. | ||||
| func (s *Server) Close() error { | ||||
| 	close(s.done) | ||||
| 	s.Listeners.CloseAll(s.closeListenerClients) | ||||
| 	s.fanpool.Close() | ||||
| 	s.fanpool.Wait() | ||||
| 	s.hooks.OnStopped() | ||||
| 	s.hooks.Stop() | ||||
|  | ||||
| @@ -1241,6 +1293,10 @@ func (s *Server) sendLWT(cl *Client) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if pk.FixedHeader.Retain { | ||||
| 		s.retainMessage(cl, pk) | ||||
| 	} | ||||
|  | ||||
| 	s.publishToSubscribers(pk)                      // [MQTT-3.1.2-8] | ||||
| 	atomic.StoreUint32(&cl.Properties.Will.Flag, 0) // [MQTT-3.1.2-10] | ||||
| 	s.hooks.OnWillSent(cl, pk) | ||||
| @@ -1315,6 +1371,7 @@ func (s *Server) loadServerInfo(v system.Info) { | ||||
| 		atomic.StoreInt64(&s.Info.ClientsDisconnected, v.ClientsDisconnected) | ||||
| 		atomic.StoreInt64(&s.Info.MessagesReceived, v.MessagesReceived) | ||||
| 		atomic.StoreInt64(&s.Info.MessagesSent, v.MessagesSent) | ||||
| 		atomic.StoreInt64(&s.Info.MessagesDropped, v.MessagesDropped) | ||||
| 		atomic.StoreInt64(&s.Info.PacketsReceived, v.PacketsReceived) | ||||
| 		atomic.StoreInt64(&s.Info.PacketsSent, v.PacketsSent) | ||||
| 		atomic.StoreInt64(&s.Info.InflightDropped, v.InflightDropped) | ||||
| @@ -1372,25 +1429,7 @@ func (s *Server) loadClients(v []storage.Client) { | ||||
| func (s *Server) loadInflight(v []storage.Message) { | ||||
| 	for _, msg := range v { | ||||
| 		if client, ok := s.Clients.Get(msg.Origin); ok { | ||||
| 			client.State.Inflight.Set(packets.Packet{ | ||||
| 				FixedHeader: msg.FixedHeader, | ||||
| 				PacketID:    msg.PacketID, | ||||
| 				TopicName:   msg.TopicName, | ||||
| 				Payload:     msg.Payload, | ||||
| 				Origin:      msg.Origin, | ||||
| 				Created:     msg.Created, | ||||
| 				Properties: packets.Properties{ | ||||
| 					PayloadFormat:          msg.Properties.PayloadFormat, | ||||
| 					PayloadFormatFlag:      msg.Properties.PayloadFormatFlag, | ||||
| 					MessageExpiryInterval:  msg.Properties.MessageExpiryInterval, | ||||
| 					ContentType:            msg.Properties.ContentType, | ||||
| 					ResponseTopic:          msg.Properties.ResponseTopic, | ||||
| 					CorrelationData:        msg.Properties.CorrelationData, | ||||
| 					SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier, | ||||
| 					TopicAlias:             msg.Properties.TopicAlias, | ||||
| 					User:                   msg.Properties.User, | ||||
| 				}, | ||||
| 			}) | ||||
| 			client.State.Inflight.Set(msg.ToPacket()) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -1398,24 +1437,7 @@ func (s *Server) loadInflight(v []storage.Message) { | ||||
| // loadRetained restores retained messages from the datastore. | ||||
| func (s *Server) loadRetained(v []storage.Message) { | ||||
| 	for _, msg := range v { | ||||
| 		s.Topics.RetainMessage(packets.Packet{ | ||||
| 			FixedHeader: msg.FixedHeader, | ||||
| 			TopicName:   msg.TopicName, | ||||
| 			Payload:     msg.Payload, | ||||
| 			Origin:      msg.Origin, | ||||
| 			Created:     msg.Created, | ||||
| 			Properties: packets.Properties{ | ||||
| 				PayloadFormat:          msg.Properties.PayloadFormat, | ||||
| 				PayloadFormatFlag:      msg.Properties.PayloadFormatFlag, | ||||
| 				MessageExpiryInterval:  msg.Properties.MessageExpiryInterval, | ||||
| 				ContentType:            msg.Properties.ContentType, | ||||
| 				ResponseTopic:          msg.Properties.ResponseTopic, | ||||
| 				CorrelationData:        msg.Properties.CorrelationData, | ||||
| 				SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier, | ||||
| 				TopicAlias:             msg.Properties.TopicAlias, | ||||
| 				User:                   msg.Properties.User, | ||||
| 			}, | ||||
| 		}) | ||||
| 		s.Topics.RetainMessage(msg.ToPacket()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -1453,8 +1475,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) { | ||||
| // clearExpiredInflights deletes any inflight messages which have expired. | ||||
| func (s *Server) clearExpiredInflights(now int64) { | ||||
| 	for _, client := range s.Clients.GetAll() { | ||||
| 		if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 { | ||||
| 			s.hooks.OnExpireInflights(client, now) | ||||
| 		if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 { | ||||
| 			for _, id := range deleted { | ||||
| 				s.hooks.OnQosDropped(client, packets.Packet{PacketID: id}) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -1465,6 +1489,9 @@ func (s *Server) sendDelayedLWT(dt int64) { | ||||
| 		if dt > pk.Expiry { | ||||
| 			s.publishToSubscribers(pk) // [MQTT-3.1.2-8] | ||||
| 			if cl, ok := s.Clients.Get(id); ok { | ||||
| 				if pk.FixedHeader.Retain { | ||||
| 					s.retainMessage(cl, pk) | ||||
| 				} | ||||
| 				cl.Properties.Will = Will{} // [MQTT-3.1.2-10] | ||||
| 				s.hooks.OnWillSent(cl, pk) | ||||
| 			} | ||||
|   | ||||
							
								
								
									
										340
									
								
								server_test.go
									
									
									
									
									
								
							
							
						
						
									
										340
									
								
								server_test.go
									
									
									
									
									
								
							| @@ -48,16 +48,31 @@ func (h *AllowHook) Provides(b byte) bool { | ||||
| func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { return true } | ||||
| func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool     { return true } | ||||
|  | ||||
| type DelayHook struct { | ||||
| 	HookBase | ||||
| 	DisconnectDelay time.Duration | ||||
| } | ||||
|  | ||||
| func (h *DelayHook) ID() string { | ||||
| 	return "delay-hook" | ||||
| } | ||||
|  | ||||
| func (h *DelayHook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{OnDisconnect}, []byte{b}) | ||||
| } | ||||
|  | ||||
| func (h *DelayHook) OnDisconnect(cl *Client, err error, expire bool) { | ||||
| 	time.Sleep(h.DisconnectDelay) | ||||
| } | ||||
|  | ||||
| func newServer() *Server { | ||||
| 	cc := *DefaultServerCapabilities | ||||
| 	cc.MaximumMessageExpiryInterval = 0 | ||||
| 	cc.ReceiveMaximum = 0 | ||||
|  | ||||
| 	s := New(&Options{ | ||||
| 		Logger:           &logger, | ||||
| 		FanPoolSize:      2, | ||||
| 		FanPoolQueueSize: 10, | ||||
| 		Capabilities:     &cc, | ||||
| 		Logger:       &logger, | ||||
| 		Capabilities: &cc, | ||||
| 	}) | ||||
| 	s.AddHook(new(AllowHook), nil) | ||||
| 	return s | ||||
| @@ -68,8 +83,6 @@ func TestOptionsSetDefaults(t *testing.T) { | ||||
| 	opts.ensureDefaults() | ||||
|  | ||||
| 	require.Equal(t, defaultSysTopicInterval, opts.SysTopicResendInterval) | ||||
| 	require.Equal(t, defaultFanPoolSize, opts.FanPoolSize) | ||||
| 	require.Equal(t, defaultFanPoolQueueSize, opts.FanPoolQueueSize) | ||||
| 	require.Equal(t, DefaultServerCapabilities, opts.Capabilities) | ||||
|  | ||||
| 	opts = new(Options) | ||||
| @@ -86,7 +99,6 @@ func TestNew(t *testing.T) { | ||||
| 	require.NotNil(t, s.Info) | ||||
| 	require.NotNil(t, s.Log) | ||||
| 	require.NotNil(t, s.Options) | ||||
| 	require.NotNil(t, s.fanpool) | ||||
| 	require.NotNil(t, s.loop) | ||||
| 	require.NotNil(t, s.loop.sysTopics) | ||||
| 	require.NotNil(t, s.loop.inflightExpiry) | ||||
| @@ -115,9 +127,9 @@ func TestServerNewClient(t *testing.T) { | ||||
| 	require.NotNil(t, cl.State.Inflight.internal) | ||||
| 	require.NotNil(t, cl.State.Subscriptions) | ||||
| 	require.NotNil(t, cl.State.TopicAliases) | ||||
| 	require.Equal(t, defaultKeepalive, cl.State.keepalive) | ||||
| 	require.Equal(t, defaultKeepalive, cl.State.Keepalive) | ||||
| 	require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion) | ||||
| 	require.NotNil(t, cl.Net.conn) | ||||
| 	require.NotNil(t, cl.Net.Conn) | ||||
| 	require.NotNil(t, cl.Net.bconn) | ||||
| 	require.NotNil(t, cl.ops) | ||||
| 	require.Equal(t, s.Log, cl.ops.log) | ||||
| @@ -406,20 +418,22 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { | ||||
|  | ||||
| 	cl, r0, _ := newTestClient() | ||||
| 	cl.Properties.ProtocolVersion = 5 | ||||
| 	cl.Properties.Username = []byte("mochi") | ||||
| 	cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier | ||||
| 	cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) | ||||
| 	cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) | ||||
|  | ||||
| 	cl.State.Inflight.Set(*packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) | ||||
| 	s.Clients.Add(cl) | ||||
|  | ||||
| 	r, w := net.Pipe() | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		o <- s.EstablishConnection("tcp", r) | ||||
| 		err := s.EstablishConnection("tcp", r) | ||||
| 		o <- err | ||||
| 	}() | ||||
|  | ||||
| 	go func() { | ||||
| 		w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes) | ||||
| 		time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect. | ||||
| 		w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) | ||||
| 	}() | ||||
|  | ||||
| @@ -445,9 +459,14 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { | ||||
| 		require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect | ||||
| 	} | ||||
|  | ||||
| 	require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, <-recv) | ||||
| 	connackPlusPacket := append( | ||||
| 		packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes, | ||||
| 		packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes..., | ||||
| 	) | ||||
| 	require.Equal(t, connackPlusPacket, <-recv) | ||||
| 	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover) | ||||
|  | ||||
| 	time.Sleep(time.Microsecond * 100) | ||||
| 	w.Close() | ||||
| 	r.Close() | ||||
|  | ||||
| @@ -455,9 +474,99 @@ func TestEstablishConnectionInheritExisting(t *testing.T) { | ||||
| 	require.True(t, ok) | ||||
| 	require.NotEmpty(t, clw.State.Subscriptions) | ||||
|  | ||||
| 	sub, ok := cl.State.Subscriptions.Get("a/b/c") | ||||
| 	// Prevent sequential takeover memory-bloom. | ||||
| 	require.Empty(t, cl.State.Subscriptions.GetAll()) | ||||
| } | ||||
|  | ||||
| // See https://github.com/mochi-co/mqtt/issues/173 | ||||
| func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	d := new(DelayHook) | ||||
| 	d.DisconnectDelay = time.Millisecond * 200 | ||||
| 	s.AddHook(d, nil) | ||||
| 	defer s.Close() | ||||
|  | ||||
| 	// Clean session, 0 session expiry interval | ||||
| 	cl1RawBytes := []byte{ | ||||
| 		packets.Connect << 4, 21, // Fixed header | ||||
| 		0, 4, // Protocol Name - MSB+LSB | ||||
| 		'M', 'Q', 'T', 'T', // Protocol Name | ||||
| 		5,      // Protocol Version | ||||
| 		1 << 1, // Packet Flags | ||||
| 		0, 30,  // Keepalive | ||||
| 		5,              // Properties length | ||||
| 		17, 0, 0, 0, 0, // Session Expiry Interval (17) | ||||
| 		0, 3, // Client ID - MSB+LSB | ||||
| 		'z', 'e', 'n', // Client ID "zen" | ||||
| 	} | ||||
|  | ||||
| 	// Make first connection | ||||
| 	r1, w1 := net.Pipe() | ||||
| 	o1 := make(chan error) | ||||
| 	go func() { | ||||
| 		err := s.EstablishConnection("tcp", r1) | ||||
| 		o1 <- err | ||||
| 	}() | ||||
| 	go func() { | ||||
| 		w1.Write(cl1RawBytes) | ||||
| 	}() | ||||
|  | ||||
| 	// receive the first connack | ||||
| 	recv := make(chan []byte) | ||||
| 	go func() { | ||||
| 		buf, err := io.ReadAll(w1) | ||||
| 		require.NoError(t, err) | ||||
| 		recv <- buf | ||||
| 	}() | ||||
|  | ||||
| 	// Get the first client pointer | ||||
| 	time.Sleep(time.Millisecond * 50) | ||||
| 	cl1, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).Packet.Connect.ClientIdentifier) | ||||
| 	require.True(t, ok) | ||||
| 	require.Equal(t, packets.Subscription{Filter: "a/b/c", Qos: 1}, sub) | ||||
| 	cl1.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) | ||||
| 	cl1.State.Subscriptions.Add("d/e/f", packets.Subscription{Filter: "d/e/f", Qos: 0}) | ||||
| 	time.Sleep(time.Millisecond * 50) | ||||
|  | ||||
| 	// Make the second connection | ||||
| 	r2, w2 := net.Pipe() | ||||
| 	o2 := make(chan error) | ||||
| 	go func() { | ||||
| 		err := s.EstablishConnection("tcp", r2) | ||||
| 		o2 <- err | ||||
| 	}() | ||||
| 	go func() { | ||||
| 		x := packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes[:] | ||||
| 		x[19] = '.' // differentiate username bytes in debugging | ||||
| 		w2.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes) | ||||
| 	}() | ||||
|  | ||||
| 	// receive the second connack | ||||
| 	recv2 := make(chan []byte) | ||||
| 	go func() { | ||||
| 		buf, err := io.ReadAll(w2) | ||||
| 		require.NoError(t, err) | ||||
| 		recv2 <- buf | ||||
| 	}() | ||||
|  | ||||
| 	// Capture first Client pointer | ||||
| 	clp1, ok := s.Clients.Get("zen") | ||||
| 	require.True(t, ok) | ||||
| 	require.Empty(t, clp1.Properties.Username) | ||||
| 	require.NotEmpty(t, clp1.State.Subscriptions.GetAll()) | ||||
|  | ||||
| 	err1 := <-o1 | ||||
| 	require.Error(t, err1) | ||||
| 	require.ErrorIs(t, err1, io.ErrClosedPipe) | ||||
|  | ||||
| 	// Capture second Client pointer | ||||
| 	clp2, ok := s.Clients.Get("zen") | ||||
| 	require.True(t, ok) | ||||
| 	require.Equal(t, []byte(".ochi"), clp2.Properties.Username) | ||||
| 	require.NotEmpty(t, clp2.State.Subscriptions.GetAll()) | ||||
| 	require.Empty(t, clp1.State.Subscriptions.GetAll()) | ||||
|  | ||||
| 	w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) | ||||
| 	require.NoError(t, <-o2) | ||||
| } | ||||
|  | ||||
| func TestEstablishConnectionResentPendingInflightsError(t *testing.T) { | ||||
| @@ -553,9 +662,7 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) { | ||||
|  | ||||
| func TestEstablishConnectionBadAuthentication(t *testing.T) { | ||||
| 	s := New(&Options{ | ||||
| 		Logger:           &logger, | ||||
| 		FanPoolSize:      2, | ||||
| 		FanPoolQueueSize: 10, | ||||
| 		Logger: &logger, | ||||
| 	}) | ||||
| 	defer s.Close() | ||||
|  | ||||
| @@ -589,9 +696,7 @@ func TestEstablishConnectionBadAuthentication(t *testing.T) { | ||||
|  | ||||
| func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) { | ||||
| 	s := New(&Options{ | ||||
| 		Logger:           &logger, | ||||
| 		FanPoolSize:      2, | ||||
| 		FanPoolQueueSize: 10, | ||||
| 		Logger: &logger, | ||||
| 	}) | ||||
| 	defer s.Close() | ||||
|  | ||||
| @@ -643,6 +748,33 @@ func TestServerEstablishConnectionInvalidConnect(t *testing.T) { | ||||
| 	r.Close() | ||||
| } | ||||
|  | ||||
| // See https://github.com/mochi-co/mqtt/issues/178 | ||||
| func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) { | ||||
| 	s := newServer() | ||||
|  | ||||
| 	r, w := net.Pipe() | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		o <- s.EstablishConnection("tcp", r) | ||||
| 	}() | ||||
|  | ||||
| 	go func() { | ||||
| 		w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes) | ||||
| 		w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes) | ||||
| 	}() | ||||
|  | ||||
| 	// receive the connack error | ||||
| 	go func() { | ||||
| 		_, err := io.ReadAll(w) | ||||
| 		require.NoError(t, err) | ||||
| 	}() | ||||
|  | ||||
| 	err := <-o | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r.Close() | ||||
| } | ||||
|  | ||||
| func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) { | ||||
| 	s := newServer() | ||||
|  | ||||
| @@ -685,17 +817,40 @@ func TestServerEstablishConnectionBadPacket(t *testing.T) { | ||||
| 	r.Close() | ||||
| } | ||||
|  | ||||
| func TestServerEstablishConnectionOnConnectError(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	hook.fail = true | ||||
| 	err := s.AddHook(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, w := net.Pipe() | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		o <- s.EstablishConnection("tcp", r) | ||||
| 	}() | ||||
|  | ||||
| 	go func() { | ||||
| 		w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) | ||||
| 	}() | ||||
|  | ||||
| 	err = <-o | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, errTestHook) | ||||
|  | ||||
| 	r.Close() | ||||
| } | ||||
|  | ||||
| func TestServerSendConnack(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	cl, r, w := newTestClient() | ||||
| 	cl.Properties.ProtocolVersion = 5 | ||||
| 	s.Options.Capabilities.ServerKeepAlive = 20 | ||||
| 	s.Options.Capabilities.MaximumQos = 1 | ||||
| 	cl.Properties.Props = packets.Properties{ | ||||
| 		AssignedClientID: "mochi", | ||||
| 	} | ||||
| 	go func() { | ||||
| 		err := s.sendConnack(cl, packets.CodeSuccess, true) | ||||
| 		err := s.SendConnack(cl, packets.CodeSuccess, true, nil) | ||||
| 		require.NoError(t, err) | ||||
| 		w.Close() | ||||
| 	}() | ||||
| @@ -709,9 +864,8 @@ func TestServerSendConnackFailureReason(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	cl, r, w := newTestClient() | ||||
| 	cl.Properties.ProtocolVersion = 5 | ||||
| 	s.Options.Capabilities.ServerKeepAlive = 20 | ||||
| 	go func() { | ||||
| 		err := s.sendConnack(cl, packets.ErrUnspecifiedError, true) | ||||
| 		err := s.SendConnack(cl, packets.ErrUnspecifiedError, true, nil) | ||||
| 		require.NoError(t, err) | ||||
| 		w.Close() | ||||
| 	}() | ||||
| @@ -721,6 +875,23 @@ func TestServerSendConnackFailureReason(t *testing.T) { | ||||
| 	require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackInvalidMinMqtt5).RawBytes, buf) | ||||
| } | ||||
|  | ||||
| func TestServerSendConnackWithServerKeepalive(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	cl, r, w := newTestClient() | ||||
| 	cl.Properties.ProtocolVersion = 5 | ||||
| 	cl.State.Keepalive = 10 | ||||
| 	cl.State.ServerKeepalive = true | ||||
| 	go func() { | ||||
| 		err := s.SendConnack(cl, packets.CodeSuccess, true, nil) | ||||
| 		require.NoError(t, err) | ||||
| 		w.Close() | ||||
| 	}() | ||||
|  | ||||
| 	buf, err := io.ReadAll(r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackServerKeepalive).RawBytes, buf) | ||||
| } | ||||
|  | ||||
| func TestServerValidateConnect(t *testing.T) { | ||||
| 	packet := *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet | ||||
| 	invalidBitPacket := packet | ||||
| @@ -790,7 +961,7 @@ func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) { | ||||
| 	cl.Properties.Props.SessionExpiryInterval = uint32(300) | ||||
| 	s.Options.Capabilities.MaximumSessionExpiryInterval = 120 | ||||
| 	go func() { | ||||
| 		err := s.sendConnack(cl, packets.CodeSuccess, false) | ||||
| 		err := s.SendConnack(cl, packets.CodeSuccess, false, nil) | ||||
| 		require.NoError(t, err) | ||||
| 		w.Close() | ||||
| 	}() | ||||
| @@ -806,7 +977,7 @@ func TestInheritClientSession(t *testing.T) { | ||||
| 	n := time.Now().Unix() | ||||
|  | ||||
| 	existing, _, _ := newTestClient() | ||||
| 	existing.Net.conn = nil | ||||
| 	existing.Net.Conn = nil | ||||
| 	existing.ID = "mochi" | ||||
| 	existing.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1}) | ||||
| 	existing.State.Inflight = NewInflights() | ||||
| @@ -844,7 +1015,7 @@ func TestServerUnsubscribeClient(t *testing.T) { | ||||
| 	s.Topics.Subscribe(cl.ID, pk) | ||||
| 	subs := s.Topics.Subscribers("a/b/c") | ||||
| 	require.Equal(t, 1, len(subs.Subscriptions)) | ||||
| 	s.unsubscribeClient(cl) | ||||
| 	s.UnsubscribeClient(cl) | ||||
| 	subs = s.Topics.Subscribers("a/b/c") | ||||
| 	require.Equal(t, 0, len(subs.Subscriptions)) | ||||
| } | ||||
| @@ -1023,7 +1194,7 @@ func TestServerProcessPacketPublishAndReceive(t *testing.T) { | ||||
| 		w2.Close() | ||||
| 	}() | ||||
|  | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, <-receiverBuf) | ||||
| 	require.Equal(t, 1, len(s.Topics.Messages("a/b/c"))) | ||||
| } | ||||
|  | ||||
| @@ -1098,9 +1269,7 @@ func TestServerProcessPublishInvalidTopic(t *testing.T) { | ||||
|  | ||||
| func TestServerProcessPublishACLCheckDeny(t *testing.T) { | ||||
| 	s := New(&Options{ | ||||
| 		Logger:           &logger, | ||||
| 		FanPoolSize:      2, | ||||
| 		FanPoolQueueSize: 10, | ||||
| 		Logger: &logger, | ||||
| 	}) | ||||
| 	s.Serve() | ||||
| 	defer s.Close() | ||||
| @@ -1383,6 +1552,7 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) { | ||||
| 		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet | ||||
| 		pkx.FixedHeader.Qos = 2 | ||||
| 		s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, pkx) | ||||
| 		time.Sleep(time.Microsecond * 100) | ||||
| 		w.Close() | ||||
| 	}() | ||||
|  | ||||
| @@ -1396,6 +1566,33 @@ func TestPublishToClientServerDowngradeQos(t *testing.T) { | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf) | ||||
| } | ||||
|  | ||||
| func TestPublishToClientExceedClientWritesPending(t *testing.T) { | ||||
| 	s := newServer() | ||||
|  | ||||
| 	_, w := net.Pipe() | ||||
| 	cl := newClient(w, &ops{ | ||||
| 		info:  new(system.Info), | ||||
| 		hooks: new(Hooks), | ||||
| 		log:   &logger, | ||||
| 		options: &Options{ | ||||
| 			Capabilities: &Capabilities{ | ||||
| 				MaximumClientWritesPending: 3, | ||||
| 			}, | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| 	s.Clients.Add(cl) | ||||
|  | ||||
| 	for i := int32(0); i < cl.ops.options.Capabilities.MaximumClientWritesPending; i++ { | ||||
| 		cl.State.outbound <- new(packets.Packet) | ||||
| 		atomic.AddInt32(&cl.State.outboundQty, 1) | ||||
| 	} | ||||
|  | ||||
| 	_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{}) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, packets.ErrPendingClientWritesExceeded, err) | ||||
| } | ||||
|  | ||||
| func TestPublishToClientServerTopicAlias(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	cl, r, w := newTestClient() | ||||
| @@ -1407,6 +1604,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) { | ||||
| 		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet | ||||
| 		s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) | ||||
| 		s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx) | ||||
| 		time.Sleep(time.Millisecond) | ||||
| 		w.Close() | ||||
| 	}() | ||||
|  | ||||
| @@ -1428,7 +1626,7 @@ func TestPublishToClientServerTopicAlias(t *testing.T) { | ||||
| func TestPublishToClientExhaustedPacketID(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	for i := 0; i <= 65535; i++ { | ||||
| 	for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ { | ||||
| 		cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) | ||||
| 	} | ||||
|  | ||||
| @@ -1440,7 +1638,7 @@ func TestPublishToClientExhaustedPacketID(t *testing.T) { | ||||
| func TestPublishToClientNoConn(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.Net.conn = nil | ||||
| 	cl.Net.Conn = nil | ||||
|  | ||||
| 	_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) | ||||
| 	require.Error(t, err) | ||||
| @@ -1497,7 +1695,7 @@ func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	cl, r, w := newTestClient() | ||||
| 	s.Clients.Add(cl) | ||||
| 	for i := 0; i <= 65535; i++ { | ||||
| 	for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ { | ||||
| 		cl.State.Inflight.Set(packets.Packet{PacketID: 1}) | ||||
| 	} | ||||
|  | ||||
| @@ -1537,7 +1735,7 @@ func TestPublishRetainedToClient(t *testing.T) { | ||||
| 	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}) | ||||
| 	require.True(t, subbed) | ||||
|  | ||||
| 	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet) | ||||
| 	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetainMqtt5).Packet) | ||||
| 	require.Equal(t, int64(1), retained) | ||||
|  | ||||
| 	go func() { | ||||
| @@ -1548,7 +1746,7 @@ func TestPublishRetainedToClient(t *testing.T) { | ||||
|  | ||||
| 	buf, err := io.ReadAll(r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, buf) | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, buf) | ||||
| } | ||||
|  | ||||
| func TestPublishRetainedToClientIsShared(t *testing.T) { | ||||
| @@ -1863,7 +2061,7 @@ func TestServerProcessInboundQos2Flow(t *testing.T) { | ||||
| 	for i, tx := range tt { | ||||
| 		t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) { | ||||
| 			r, w = net.Pipe() | ||||
| 			cl.Net.conn = w | ||||
| 			cl.Net.Conn = w | ||||
|  | ||||
| 			recv := make(chan []byte) | ||||
| 			go func() { // receive the ack | ||||
| @@ -1937,7 +2135,8 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) { | ||||
| 	for i, tx := range tt { | ||||
| 		t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) { | ||||
| 			r, w := net.Pipe() | ||||
| 			cl.Net.conn = w | ||||
| 			time.Sleep(time.Millisecond) | ||||
| 			cl.Net.Conn = w | ||||
|  | ||||
| 			recv := make(chan []byte) | ||||
| 			go func() { // receive the ack | ||||
| @@ -1953,6 +2152,7 @@ func TestServerProcessOutboundQos2Flow(t *testing.T) { | ||||
| 				require.NoError(t, err) | ||||
| 			} | ||||
|  | ||||
| 			time.Sleep(time.Millisecond) | ||||
| 			w.Close() | ||||
|  | ||||
| 			if i != 2 { | ||||
| @@ -2064,7 +2264,7 @@ func TestServerProcessSubscribeWithRetain(t *testing.T) { | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, append( | ||||
| 		packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, | ||||
| 		packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes..., | ||||
| 		packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes..., | ||||
| 	), buf) | ||||
| } | ||||
|  | ||||
| @@ -2164,9 +2364,7 @@ func TestServerProcessSubscribeNoConnection(t *testing.T) { | ||||
|  | ||||
| func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { | ||||
| 	s := New(&Options{ | ||||
| 		Logger:           &logger, | ||||
| 		FanPoolSize:      2, | ||||
| 		FanPoolQueueSize: 10, | ||||
| 		Logger: &logger, | ||||
| 	}) | ||||
| 	s.Serve() | ||||
| 	cl, r, w := newTestClient() | ||||
| @@ -2185,9 +2383,7 @@ func TestServerProcessSubscribeACLCheckDeny(t *testing.T) { | ||||
|  | ||||
| func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) { | ||||
| 	s := New(&Options{ | ||||
| 		Logger:           &logger, | ||||
| 		FanPoolSize:      2, | ||||
| 		FanPoolQueueSize: 10, | ||||
| 		Logger: &logger, | ||||
| 	}) | ||||
| 	s.Serve() | ||||
| 	s.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true | ||||
| @@ -2319,7 +2515,7 @@ func TestServerProcessPacketDisconnect(t *testing.T) { | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.Equal(t, 0, s.loop.willDelayed.Len()) | ||||
| 	require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done)) | ||||
| 	require.True(t, cl.Closed()) | ||||
| 	require.Equal(t, time.Now().Unix(), atomic.LoadInt64(&cl.State.disconnected)) | ||||
| } | ||||
|  | ||||
| @@ -2414,6 +2610,46 @@ func TestServerSendLWT(t *testing.T) { | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf) | ||||
| } | ||||
|  | ||||
| func TestServerSendLWTRetain(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	s.Serve() | ||||
| 	defer s.Close() | ||||
|  | ||||
| 	sender, _, w1 := newTestClient() | ||||
| 	sender.ID = "sender" | ||||
| 	sender.Properties.Will = Will{ | ||||
| 		Flag:      1, | ||||
| 		TopicName: "a/b/c", | ||||
| 		Payload:   []byte("hello mochi"), | ||||
| 		Retain:    true, | ||||
| 	} | ||||
| 	s.Clients.Add(sender) | ||||
|  | ||||
| 	receiver, r2, w2 := newTestClient() | ||||
| 	receiver.ID = "receiver" | ||||
| 	s.Clients.Add(receiver) | ||||
| 	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}) | ||||
|  | ||||
| 	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived)) | ||||
| 	require.Equal(t, 0, len(s.Topics.Messages("a/b/c"))) | ||||
|  | ||||
| 	receiverBuf := make(chan []byte) | ||||
| 	go func() { | ||||
| 		buf, err := io.ReadAll(r2) | ||||
| 		require.NoError(t, err) | ||||
| 		receiverBuf <- buf | ||||
| 	}() | ||||
|  | ||||
| 	go func() { | ||||
| 		s.sendLWT(sender) | ||||
| 		time.Sleep(time.Millisecond * 10) | ||||
| 		w1.Close() | ||||
| 		w2.Close() | ||||
| 	}() | ||||
|  | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, <-receiverBuf) | ||||
| } | ||||
|  | ||||
| func TestServerSendLWTDelayed(t *testing.T) { | ||||
| 	s := newServer() | ||||
| 	cl1, _, _ := newTestClient() | ||||
| @@ -2452,7 +2688,7 @@ func TestServerSendLWTDelayed(t *testing.T) { | ||||
| 		recv <- buf | ||||
| 	}() | ||||
|  | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-recv) | ||||
| 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, <-recv) | ||||
| } | ||||
|  | ||||
| func TestServerReadStore(t *testing.T) { | ||||
| @@ -2573,7 +2809,6 @@ func TestServerClose(t *testing.T) { | ||||
| 	err := s.AddListener(listeners.NewMockListener("t1", ":1882")) | ||||
| 	require.NoError(t, err) | ||||
| 	s.Serve() | ||||
| 	require.Equal(t, uint64(2), s.fanpool.Size()) | ||||
|  | ||||
| 	// receive the disconnect | ||||
| 	recv := make(chan []byte) | ||||
| @@ -2593,7 +2828,6 @@ func TestServerClose(t *testing.T) { | ||||
| 	s.Close() | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	require.Equal(t, false, listener.(*listeners.MockListener).IsServing()) | ||||
| 	require.Equal(t, uint64(0), s.fanpool.Size()) | ||||
| 	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-recv) | ||||
| } | ||||
|  | ||||
| @@ -2651,7 +2885,7 @@ func TestServerClearExpiredClients(t *testing.T) { | ||||
| 	cl0, _, _ := newTestClient() | ||||
| 	cl0.ID = "c0" | ||||
| 	cl0.State.disconnected = n - 10 | ||||
| 	cl0.State.done = 1 | ||||
| 	cl0.State.cancelOpen() | ||||
| 	cl0.Properties.ProtocolVersion = 5 | ||||
| 	cl0.Properties.Props.SessionExpiryInterval = 12 | ||||
| 	cl0.Properties.Props.SessionExpiryIntervalFlag = true | ||||
| @@ -2661,7 +2895,7 @@ func TestServerClearExpiredClients(t *testing.T) { | ||||
| 	cl1, _, _ := newTestClient() | ||||
| 	cl1.ID = "c1" | ||||
| 	cl1.State.disconnected = n - 10 | ||||
| 	cl1.State.done = 1 | ||||
| 	cl1.State.cancelOpen() | ||||
| 	cl1.Properties.ProtocolVersion = 5 | ||||
| 	cl1.Properties.Props.SessionExpiryInterval = 8 | ||||
| 	cl1.Properties.Props.SessionExpiryIntervalFlag = true | ||||
| @@ -2671,7 +2905,7 @@ func TestServerClearExpiredClients(t *testing.T) { | ||||
| 	cl2, _, _ := newTestClient() | ||||
| 	cl2.ID = "c2" | ||||
| 	cl2.State.disconnected = n - 10 | ||||
| 	cl2.State.done = 1 | ||||
| 	cl2.State.cancelOpen() | ||||
| 	cl2.Properties.ProtocolVersion = 5 | ||||
| 	cl2.Properties.Props.SessionExpiryInterval = 0 | ||||
| 	cl2.Properties.Props.SessionExpiryIntervalFlag = true | ||||
|   | ||||
| @@ -4,6 +4,8 @@ | ||||
|  | ||||
| package system | ||||
|  | ||||
| import "sync/atomic" | ||||
|  | ||||
| // Info contains atomic counters and values for various server statistics | ||||
| // commonly found in $SYS topics (and others). | ||||
| // based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics | ||||
| @@ -20,6 +22,7 @@ type Info struct { | ||||
| 	ClientsTotal        int64  `json:"clients_total"`        // total number of connected and disconnected clients with a persistent session currently connected and registered | ||||
| 	MessagesReceived    int64  `json:"messages_received"`    // total number of publish messages received | ||||
| 	MessagesSent        int64  `json:"messages_sent"`        // total number of publish messages sent | ||||
| 	MessagesDropped     int64  `json:"messages_dropped"`     // total number of publish messages dropped to slow subscriber | ||||
| 	Retained            int64  `json:"retained"`             // total number of retained messages active on the broker | ||||
| 	Inflight            int64  `json:"inflight"`             // the number of messages currently in-flight | ||||
| 	InflightDropped     int64  `json:"inflight_dropped"`     // the number of inflight messages which were dropped | ||||
| @@ -29,3 +32,30 @@ type Info struct { | ||||
| 	MemoryAlloc         int64  `json:"memory_alloc"`         // memory currently allocated | ||||
| 	Threads             int64  `json:"threads"`              // number of active goroutines, named as threads for platform ambiguity | ||||
| } | ||||
|  | ||||
| // Clone makes a copy of Info using atomic operation | ||||
| func (i *Info) Clone() *Info { | ||||
| 	return &Info{ | ||||
| 		Version:             i.Version, | ||||
| 		Started:             atomic.LoadInt64(&i.Started), | ||||
| 		Time:                atomic.LoadInt64(&i.Time), | ||||
| 		Uptime:              atomic.LoadInt64(&i.Uptime), | ||||
| 		BytesReceived:       atomic.LoadInt64(&i.BytesReceived), | ||||
| 		BytesSent:           atomic.LoadInt64(&i.BytesSent), | ||||
| 		ClientsConnected:    atomic.LoadInt64(&i.ClientsConnected), | ||||
| 		ClientsMaximum:      atomic.LoadInt64(&i.ClientsMaximum), | ||||
| 		ClientsTotal:        atomic.LoadInt64(&i.ClientsTotal), | ||||
| 		ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected), | ||||
| 		MessagesReceived:    atomic.LoadInt64(&i.MessagesReceived), | ||||
| 		MessagesSent:        atomic.LoadInt64(&i.MessagesSent), | ||||
| 		MessagesDropped:     atomic.LoadInt64(&i.MessagesDropped), | ||||
| 		Retained:            atomic.LoadInt64(&i.Retained), | ||||
| 		Inflight:            atomic.LoadInt64(&i.Inflight), | ||||
| 		InflightDropped:     atomic.LoadInt64(&i.InflightDropped), | ||||
| 		Subscriptions:       atomic.LoadInt64(&i.Subscriptions), | ||||
| 		PacketsReceived:     atomic.LoadInt64(&i.PacketsReceived), | ||||
| 		PacketsSent:         atomic.LoadInt64(&i.PacketsSent), | ||||
| 		MemoryAlloc:         atomic.LoadInt64(&i.MemoryAlloc), | ||||
| 		Threads:             atomic.LoadInt64(&i.Threads), | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										37
									
								
								system/system_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								system/system_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package system | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestClone(t *testing.T) { | ||||
| 	o := &Info{ | ||||
| 		Version:             "version", | ||||
| 		Started:             1, | ||||
| 		Time:                2, | ||||
| 		Uptime:              3, | ||||
| 		BytesReceived:       4, | ||||
| 		BytesSent:           5, | ||||
| 		ClientsConnected:    6, | ||||
| 		ClientsMaximum:      7, | ||||
| 		ClientsTotal:        8, | ||||
| 		ClientsDisconnected: 9, | ||||
| 		MessagesReceived:    10, | ||||
| 		MessagesSent:        11, | ||||
| 		MessagesDropped:     20, | ||||
| 		Retained:            12, | ||||
| 		Inflight:            13, | ||||
| 		InflightDropped:     14, | ||||
| 		Subscriptions:       15, | ||||
| 		PacketsReceived:     16, | ||||
| 		PacketsSent:         17, | ||||
| 		MemoryAlloc:         18, | ||||
| 		Threads:             19, | ||||
| 	} | ||||
|  | ||||
| 	n := o.Clone() | ||||
|  | ||||
| 	require.Equal(t, o, n) | ||||
| } | ||||
							
								
								
									
										41
									
								
								topics.go
									
									
									
									
									
								
							
							
						
						
									
										41
									
								
								topics.go
									
									
									
									
									
								
							| @@ -301,6 +301,9 @@ func NewTopicsIndex() *TopicsIndex { | ||||
| // Subscribe adds a new subscription for a client to a topic filter, returning | ||||
| // true if the subscription was new. | ||||
| func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool { | ||||
| 	x.root.Lock() | ||||
| 	defer x.root.Unlock() | ||||
|  | ||||
| 	var existed bool | ||||
| 	prefix, _ := isolateParticle(subscription.Filter, 0) | ||||
| 	if strings.EqualFold(prefix, SharePrefix) { | ||||
| @@ -320,8 +323,13 @@ func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription | ||||
| // Unsubscribe removes a subscription filter for a client, returning true if the | ||||
| // subscription existed. | ||||
| func (x *TopicsIndex) Unsubscribe(filter, client string) bool { | ||||
| 	x.root.Lock() | ||||
| 	defer x.root.Unlock() | ||||
|  | ||||
| 	var d int | ||||
| 	if strings.HasPrefix(filter, SharePrefix) { | ||||
| 	prefix, _ := isolateParticle(filter, 0) | ||||
| 	shareSub := strings.EqualFold(prefix, SharePrefix) | ||||
| 	if shareSub { | ||||
| 		d = 2 | ||||
| 	} | ||||
|  | ||||
| @@ -330,8 +338,7 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	prefix, _ := isolateParticle(filter, 0) | ||||
| 	if strings.EqualFold(prefix, SharePrefix) { | ||||
| 	if shareSub { | ||||
| 		group, _ := isolateParticle(filter, 1) | ||||
| 		particle.shared.Delete(group, client) | ||||
| 	} else { | ||||
| @@ -346,7 +353,12 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool { | ||||
| // 1 if a retained message was added, and -1 if the retained message was removed. | ||||
| // 0 is returned if sequential empty payloads are received. | ||||
| func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 { | ||||
| 	x.root.Lock() | ||||
| 	defer x.root.Unlock() | ||||
|  | ||||
| 	n := x.set(pk.TopicName, 0) | ||||
| 	n.Lock() | ||||
| 	defer n.Unlock() | ||||
| 	if len(pk.Payload) > 0 { | ||||
| 		n.retainPath = pk.TopicName | ||||
| 		x.Retained.Add(pk.TopicName, pk) | ||||
| @@ -361,6 +373,7 @@ func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 { | ||||
| 	n.retainPath = "" | ||||
| 	x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7] | ||||
| 	x.trim(n) | ||||
|  | ||||
| 	return out | ||||
| } | ||||
|  | ||||
| @@ -488,20 +501,27 @@ func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Su | ||||
| 	} | ||||
|  | ||||
| 	key, hasNext := isolateParticle(topic, d) | ||||
| 	for _, partKey := range []string{key, "+", "#"} { | ||||
| 	for _, partKey := range []string{key, "+"} { | ||||
| 		if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3] | ||||
| 			x.gatherSubscriptions(topic, particle, subs) | ||||
| 			x.gatherSharedSubscriptions(particle, subs) | ||||
| 			if wild := particle.particles.get("#"); wild != nil && partKey != "#" && partKey != "+" { | ||||
| 				x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2 | ||||
| 			} | ||||
|  | ||||
| 			if hasNext { | ||||
| 				x.scanSubscribers(topic, d+1, particle, subs) | ||||
| 			} else { | ||||
| 				x.gatherSubscriptions(topic, particle, subs) | ||||
| 				x.gatherSharedSubscriptions(particle, subs) | ||||
|  | ||||
| 				if wild := particle.particles.get("#"); wild != nil && partKey != "+" { | ||||
| 					x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2 | ||||
| 					x.gatherSharedSubscriptions(wild, subs) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if particle := n.particles.get("#"); particle != nil { | ||||
| 		x.gatherSubscriptions(topic, particle, subs) | ||||
| 		x.gatherSharedSubscriptions(particle, subs) | ||||
| 	} | ||||
|  | ||||
| 	return subs | ||||
| } | ||||
|  | ||||
| @@ -619,6 +639,7 @@ type particle struct { | ||||
| 	subscriptions *Subscriptions       // a map of subscriptions made by clients to this ending address | ||||
| 	shared        *SharedSubscriptions // a map of shared subscriptions keyed on group name | ||||
| 	retainPath    string               // path of a retained message | ||||
| 	sync.Mutex                         // mutex for when making changes to the particle | ||||
| } | ||||
|  | ||||
| // newParticle returns a pointer to a new instance of particle. | ||||
|   | ||||
| @@ -319,7 +319,7 @@ func TestUnsubscribeShared(t *testing.T) { | ||||
| 	require.True(t, exists) | ||||
| 	require.Equal(t, byte(2), client.Qos) | ||||
|  | ||||
| 	require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1")) | ||||
| 	require.True(t, index.Unsubscribe("$share/tmp/a/b/c", "cl1")) | ||||
| 	_, exists = final.shared.Get("tmp", "cl1") | ||||
| 	require.False(t, exists) | ||||
| } | ||||
| @@ -501,28 +501,40 @@ func TestScanSubscribers(t *testing.T) { | ||||
| 	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2}) | ||||
|  | ||||
| 	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers)) | ||||
| 	require.Equal(t, 4, len(subs.Subscriptions)) | ||||
| 	require.Equal(t, 3, len(subs.Subscriptions)) | ||||
| 	require.Contains(t, subs.Subscriptions, "cl1") | ||||
| 	require.Contains(t, subs.Subscriptions, "cl2") | ||||
| 	require.Contains(t, subs.Subscriptions, "cl3") | ||||
| 	require.Contains(t, subs.Subscriptions, "cl4") | ||||
|  | ||||
| 	require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos) | ||||
| 	require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos) | ||||
| 	require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos) | ||||
| 	require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos) | ||||
|  | ||||
| 	require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"]) | ||||
| 	require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"]) | ||||
| 	require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"]) | ||||
| 	require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"]) | ||||
| 	require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"]) | ||||
| 	require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"]) | ||||
|  | ||||
| 	subs = index.scanSubscribers("d/e/f/g", 0, nil, new(Subscribers)) | ||||
| 	require.Equal(t, 1, len(subs.Subscriptions)) | ||||
| 	require.Contains(t, subs.Subscriptions, "cl4") | ||||
| 	require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos) | ||||
| 	require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"]) | ||||
|  | ||||
| 	subs = index.scanSubscribers("", 0, nil, new(Subscribers)) | ||||
| 	require.Equal(t, 0, len(subs.Subscriptions)) | ||||
| } | ||||
|  | ||||
| func TestScanSubscribersTopicInheritanceBug(t *testing.T) { | ||||
| 	index := NewTopicsIndex() | ||||
| 	index.Subscribe("cl1", packets.Subscription{Qos: 0, Filter: "a/b/c"}) | ||||
| 	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/b"}) | ||||
|  | ||||
| 	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers)) | ||||
| 	require.Equal(t, 1, len(subs.Subscriptions)) | ||||
| } | ||||
|  | ||||
| func TestScanSubscribersShared(t *testing.T) { | ||||
| 	index := NewTopicsIndex() | ||||
| 	index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111}) | ||||
| @@ -531,8 +543,9 @@ func TestScanSubscribersShared(t *testing.T) { | ||||
| 	index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10}) | ||||
| 	index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200}) | ||||
| 	index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201}) | ||||
| 	index.Subscribe("cl5", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c/#"}) | ||||
| 	subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers)) | ||||
| 	require.Equal(t, 3, len(subs.Shared)) | ||||
| 	require.Equal(t, 4, len(subs.Shared)) | ||||
| } | ||||
|  | ||||
| func TestSelectSharedSubscriber(t *testing.T) { | ||||
|   | ||||
							
								
								
									
										2
									
								
								vendor/golang.org/x/net/trace/histogram.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								vendor/golang.org/x/net/trace/histogram.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -32,7 +32,7 @@ type histogram struct { | ||||
| 	valueCount   int64   // number of values recorded for single value | ||||
| } | ||||
|  | ||||
| // AddMeasurement records a value measurement observation to the histogram. | ||||
| // addMeasurement records a value measurement observation to the histogram. | ||||
| func (h *histogram) addMeasurement(value int64) { | ||||
| 	// TODO: assert invariant | ||||
| 	h.sum += value | ||||
|   | ||||
							
								
								
									
										2
									
								
								vendor/golang.org/x/net/trace/trace.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								vendor/golang.org/x/net/trace/trace.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -395,7 +395,7 @@ func New(family, title string) Trace { | ||||
| } | ||||
|  | ||||
| func (tr *trace) Finish() { | ||||
| 	elapsed := time.Now().Sub(tr.Start) | ||||
| 	elapsed := time.Since(tr.Start) | ||||
| 	tr.mu.Lock() | ||||
| 	tr.Elapsed = elapsed | ||||
| 	tr.mu.Unlock() | ||||
|   | ||||
							
								
								
									
										30
									
								
								vendor/golang.org/x/sys/internal/unsafeheader/unsafeheader.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										30
									
								
								vendor/golang.org/x/sys/internal/unsafeheader/unsafeheader.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -1,30 +0,0 @@ | ||||
| // Copyright 2020 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| // Package unsafeheader contains header declarations for the Go runtime's | ||||
| // slice and string implementations. | ||||
| // | ||||
| // This package allows x/sys to use types equivalent to | ||||
| // reflect.SliceHeader and reflect.StringHeader without introducing | ||||
| // a dependency on the (relatively heavy) "reflect" package. | ||||
| package unsafeheader | ||||
|  | ||||
| import ( | ||||
| 	"unsafe" | ||||
| ) | ||||
|  | ||||
| // Slice is the runtime representation of a slice. | ||||
| // It cannot be used safely or portably and its representation may change in a later release. | ||||
| type Slice struct { | ||||
| 	Data unsafe.Pointer | ||||
| 	Len  int | ||||
| 	Cap  int | ||||
| } | ||||
|  | ||||
| // String is the runtime representation of a string. | ||||
| // It cannot be used safely or portably and its representation may change in a later release. | ||||
| type String struct { | ||||
| 	Data unsafe.Pointer | ||||
| 	Len  int | ||||
| } | ||||
							
								
								
									
										31
									
								
								vendor/golang.org/x/sys/unix/asm_bsd_ppc64.s
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								vendor/golang.org/x/sys/unix/asm_bsd_ppc64.s
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| // Copyright 2022 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build (darwin || freebsd || netbsd || openbsd) && gc | ||||
| // +build darwin freebsd netbsd openbsd | ||||
| // +build gc | ||||
|  | ||||
| #include "textflag.h" | ||||
|  | ||||
| // | ||||
| // System call support for ppc64, BSD | ||||
| // | ||||
|  | ||||
| // Just jump to package syscall's implementation for all these functions. | ||||
| // The runtime may know about them. | ||||
|  | ||||
| TEXT	·Syscall(SB),NOSPLIT,$0-56 | ||||
| 	JMP	syscall·Syscall(SB) | ||||
|  | ||||
| TEXT	·Syscall6(SB),NOSPLIT,$0-80 | ||||
| 	JMP	syscall·Syscall6(SB) | ||||
|  | ||||
| TEXT	·Syscall9(SB),NOSPLIT,$0-104 | ||||
| 	JMP	syscall·Syscall9(SB) | ||||
|  | ||||
| TEXT	·RawSyscall(SB),NOSPLIT,$0-56 | ||||
| 	JMP	syscall·RawSyscall(SB) | ||||
|  | ||||
| TEXT	·RawSyscall6(SB),NOSPLIT,$0-80 | ||||
| 	JMP	syscall·RawSyscall6(SB) | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/dirent.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/dirent.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -2,8 +2,8 @@ | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris | ||||
| // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris | ||||
| //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos | ||||
| // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos | ||||
|  | ||||
| package unix | ||||
|  | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/gccgo.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/gccgo.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -2,8 +2,8 @@ | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build gccgo && !aix | ||||
| // +build gccgo,!aix | ||||
| //go:build gccgo && !aix && !hurd | ||||
| // +build gccgo,!aix,!hurd | ||||
|  | ||||
| package unix | ||||
|  | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/gccgo_c.c
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/gccgo_c.c
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -2,8 +2,8 @@ | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| // +build gccgo | ||||
| // +build !aix | ||||
| //go:build gccgo && !aix && !hurd | ||||
| // +build gccgo,!aix,!hurd | ||||
|  | ||||
| #include <errno.h> | ||||
| #include <stdint.h> | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/ioctl.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/ioctl.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -2,8 +2,8 @@ | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris | ||||
| // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris | ||||
| //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris | ||||
| // +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris | ||||
|  | ||||
| package unix | ||||
|  | ||||
|   | ||||
							
								
								
									
										20
									
								
								vendor/golang.org/x/sys/unix/ioctl_linux.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										20
									
								
								vendor/golang.org/x/sys/unix/ioctl_linux.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -4,9 +4,7 @@ | ||||
|  | ||||
| package unix | ||||
|  | ||||
| import ( | ||||
| 	"unsafe" | ||||
| ) | ||||
| import "unsafe" | ||||
|  | ||||
| // IoctlRetInt performs an ioctl operation specified by req on a device | ||||
| // associated with opened file descriptor fd, and returns a non-negative | ||||
| @@ -217,3 +215,19 @@ func IoctlKCMAttach(fd int, info KCMAttach) error { | ||||
| func IoctlKCMUnattach(fd int, info KCMUnattach) error { | ||||
| 	return ioctlPtr(fd, SIOCKCMUNATTACH, unsafe.Pointer(&info)) | ||||
| } | ||||
|  | ||||
| // IoctlLoopGetStatus64 gets the status of the loop device associated with the | ||||
| // file descriptor fd using the LOOP_GET_STATUS64 operation. | ||||
| func IoctlLoopGetStatus64(fd int) (*LoopInfo64, error) { | ||||
| 	var value LoopInfo64 | ||||
| 	if err := ioctlPtr(fd, LOOP_GET_STATUS64, unsafe.Pointer(&value)); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &value, nil | ||||
| } | ||||
|  | ||||
| // IoctlLoopSetStatus64 sets the status of the loop device associated with the | ||||
| // file descriptor fd using the LOOP_SET_STATUS64 operation. | ||||
| func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error { | ||||
| 	return ioctlPtr(fd, LOOP_SET_STATUS64, unsafe.Pointer(value)) | ||||
| } | ||||
|   | ||||
							
								
								
									
										49
									
								
								vendor/golang.org/x/sys/unix/mkall.sh
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										49
									
								
								vendor/golang.org/x/sys/unix/mkall.sh
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -73,12 +73,12 @@ aix_ppc64) | ||||
| darwin_amd64) | ||||
| 	mkerrors="$mkerrors -m64" | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs" | ||||
| 	mkasm="go run mkasm_darwin.go" | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	;; | ||||
| darwin_arm64) | ||||
| 	mkerrors="$mkerrors -m64" | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs" | ||||
| 	mkasm="go run mkasm_darwin.go" | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	;; | ||||
| dragonfly_amd64) | ||||
| 	mkerrors="$mkerrors -m64" | ||||
| @@ -142,42 +142,60 @@ netbsd_arm64) | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs" | ||||
| 	;; | ||||
| openbsd_386) | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	mkerrors="$mkerrors -m32" | ||||
| 	mksyscall="go run mksyscall.go -l32 -openbsd" | ||||
| 	mksyscall="go run mksyscall.go -l32 -openbsd -libc" | ||||
| 	mksysctl="go run mksysctl_openbsd.go" | ||||
| 	mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'" | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs" | ||||
| 	;; | ||||
| openbsd_amd64) | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	mkerrors="$mkerrors -m64" | ||||
| 	mksyscall="go run mksyscall.go -openbsd" | ||||
| 	mksyscall="go run mksyscall.go -openbsd -libc" | ||||
| 	mksysctl="go run mksysctl_openbsd.go" | ||||
| 	mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'" | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs" | ||||
| 	;; | ||||
| openbsd_arm) | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	mkerrors="$mkerrors" | ||||
| 	mksyscall="go run mksyscall.go -l32 -openbsd -arm" | ||||
| 	mksyscall="go run mksyscall.go -l32 -openbsd -arm -libc" | ||||
| 	mksysctl="go run mksysctl_openbsd.go" | ||||
| 	mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'" | ||||
| 	# Let the type of C char be signed for making the bare syscall | ||||
| 	# API consistent across platforms. | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char" | ||||
| 	;; | ||||
| openbsd_arm64) | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	mkerrors="$mkerrors -m64" | ||||
| 	mksyscall="go run mksyscall.go -openbsd" | ||||
| 	mksyscall="go run mksyscall.go -openbsd -libc" | ||||
| 	mksysctl="go run mksysctl_openbsd.go" | ||||
| 	mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'" | ||||
| 	# Let the type of C char be signed for making the bare syscall | ||||
| 	# API consistent across platforms. | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char" | ||||
| 	;; | ||||
| openbsd_mips64) | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	mkerrors="$mkerrors -m64" | ||||
| 	mksyscall="go run mksyscall.go -openbsd" | ||||
| 	mksyscall="go run mksyscall.go -openbsd -libc" | ||||
| 	mksysctl="go run mksysctl_openbsd.go" | ||||
| 	# Let the type of C char be signed for making the bare syscall | ||||
| 	# API consistent across platforms. | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char" | ||||
| 	;; | ||||
| openbsd_ppc64) | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	mkerrors="$mkerrors -m64" | ||||
| 	mksyscall="go run mksyscall.go -openbsd -libc" | ||||
| 	mksysctl="go run mksysctl_openbsd.go" | ||||
| 	# Let the type of C char be signed for making the bare syscall | ||||
| 	# API consistent across platforms. | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char" | ||||
| 	;; | ||||
| openbsd_riscv64) | ||||
| 	mkasm="go run mkasm.go" | ||||
| 	mkerrors="$mkerrors -m64" | ||||
| 	mksyscall="go run mksyscall.go -openbsd -libc" | ||||
| 	mksysctl="go run mksysctl_openbsd.go" | ||||
| 	mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'" | ||||
| 	# Let the type of C char be signed for making the bare syscall | ||||
| 	# API consistent across platforms. | ||||
| 	mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char" | ||||
| @@ -214,11 +232,6 @@ esac | ||||
| 			if [ "$GOOSARCH" == "aix_ppc64" ]; then | ||||
| 				# aix/ppc64 script generates files instead of writing to stdin. | ||||
| 				echo "$mksyscall -tags $GOOS,$GOARCH $syscall_goos $GOOSARCH_in && gofmt -w zsyscall_$GOOSARCH.go && gofmt -w zsyscall_"$GOOSARCH"_gccgo.go && gofmt -w zsyscall_"$GOOSARCH"_gc.go " ; | ||||
| 			elif [ "$GOOS" == "darwin" ]; then | ||||
| 			        # 1.12 and later, syscalls via libSystem | ||||
| 				echo "$mksyscall -tags $GOOS,$GOARCH,go1.12 $syscall_goos $GOOSARCH_in |gofmt >zsyscall_$GOOSARCH.go"; | ||||
| 				# 1.13 and later, syscalls via libSystem (including syscallPtr) | ||||
| 				echo "$mksyscall -tags $GOOS,$GOARCH,go1.13 syscall_darwin.1_13.go |gofmt >zsyscall_$GOOSARCH.1_13.go"; | ||||
| 			elif [ "$GOOS" == "illumos" ]; then | ||||
| 			        # illumos code generation requires a --illumos switch | ||||
| 			        echo "$mksyscall -illumos -tags illumos,$GOARCH syscall_illumos.go |gofmt > zsyscall_illumos_$GOARCH.go"; | ||||
| @@ -232,5 +245,5 @@ esac | ||||
| 	if [ -n "$mksysctl" ]; then echo "$mksysctl |gofmt >$zsysctl"; fi | ||||
| 	if [ -n "$mksysnum" ]; then echo "$mksysnum |gofmt >zsysnum_$GOOSARCH.go"; fi | ||||
| 	if [ -n "$mktypes" ]; then echo "$mktypes types_$GOOS.go | go run mkpost.go > ztypes_$GOOSARCH.go"; fi | ||||
| 	if [ -n "$mkasm" ]; then echo "$mkasm $GOARCH"; fi | ||||
| 	if [ -n "$mkasm" ]; then echo "$mkasm $GOOS $GOARCH"; fi | ||||
| ) | $run | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/mkerrors.sh
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/mkerrors.sh
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -642,7 +642,7 @@ errors=$( | ||||
| signals=$( | ||||
| 	echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags | | ||||
| 	awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' | | ||||
| 	egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' | | ||||
| 	grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' | | ||||
| 	sort | ||||
| ) | ||||
|  | ||||
| @@ -652,7 +652,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags | | ||||
| 	sort >_error.grep | ||||
| echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags | | ||||
| 	awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' | | ||||
| 	egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' | | ||||
| 	grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' | | ||||
| 	sort >_signal.grep | ||||
|  | ||||
| echo '// mkerrors.sh' "$@" | ||||
|   | ||||
							
								
								
									
										14
									
								
								vendor/golang.org/x/sys/unix/sockcmsg_unix.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								vendor/golang.org/x/sys/unix/sockcmsg_unix.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -52,6 +52,20 @@ func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) { | ||||
| 	return msgs, nil | ||||
| } | ||||
|  | ||||
| // ParseOneSocketControlMessage parses a single socket control message from b, returning the message header, | ||||
| // message data (a slice of b), and the remainder of b after that single message. | ||||
| // When there are no remaining messages, len(remainder) == 0. | ||||
| func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) { | ||||
| 	h, dbuf, err := socketControlMessageHeaderAndData(b) | ||||
| 	if err != nil { | ||||
| 		return Cmsghdr{}, nil, nil, err | ||||
| 	} | ||||
| 	if i := cmsgAlignOf(int(h.Len)); i < len(b) { | ||||
| 		remainder = b[i:] | ||||
| 	} | ||||
| 	return *h, dbuf, remainder, nil | ||||
| } | ||||
|  | ||||
| func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) { | ||||
| 	h := (*Cmsghdr)(unsafe.Pointer(&b[0])) | ||||
| 	if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) { | ||||
|   | ||||
							
								
								
									
										27
									
								
								vendor/golang.org/x/sys/unix/str.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										27
									
								
								vendor/golang.org/x/sys/unix/str.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -1,27 +0,0 @@ | ||||
| // Copyright 2009 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris | ||||
| // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris | ||||
|  | ||||
| package unix | ||||
|  | ||||
| func itoa(val int) string { // do it here rather than with fmt to avoid dependency | ||||
| 	if val < 0 { | ||||
| 		return "-" + uitoa(uint(-val)) | ||||
| 	} | ||||
| 	return uitoa(uint(val)) | ||||
| } | ||||
|  | ||||
| func uitoa(val uint) string { | ||||
| 	var buf [32]byte // big enough for int64 | ||||
| 	i := len(buf) - 1 | ||||
| 	for val >= 10 { | ||||
| 		buf[i] = byte(val%10 + '0') | ||||
| 		i-- | ||||
| 		val /= 10 | ||||
| 	} | ||||
| 	buf[i] = byte(val + '0') | ||||
| 	return string(buf[i:]) | ||||
| } | ||||
							
								
								
									
										10
									
								
								vendor/golang.org/x/sys/unix/syscall.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								vendor/golang.org/x/sys/unix/syscall.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -29,8 +29,6 @@ import ( | ||||
| 	"bytes" | ||||
| 	"strings" | ||||
| 	"unsafe" | ||||
|  | ||||
| 	"golang.org/x/sys/internal/unsafeheader" | ||||
| ) | ||||
|  | ||||
| // ByteSliceFromString returns a NUL-terminated slice of bytes | ||||
| @@ -82,13 +80,7 @@ func BytePtrToString(p *byte) string { | ||||
| 		ptr = unsafe.Pointer(uintptr(ptr) + 1) | ||||
| 	} | ||||
|  | ||||
| 	var s []byte | ||||
| 	h := (*unsafeheader.Slice)(unsafe.Pointer(&s)) | ||||
| 	h.Data = unsafe.Pointer(p) | ||||
| 	h.Len = n | ||||
| 	h.Cap = n | ||||
|  | ||||
| 	return string(s) | ||||
| 	return string(unsafe.Slice(p, n)) | ||||
| } | ||||
|  | ||||
| // Single-word zero for use when we need a valid pointer to 0 bytes. | ||||
|   | ||||
							
								
								
									
										2
									
								
								vendor/golang.org/x/sys/unix/syscall_aix.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								vendor/golang.org/x/sys/unix/syscall_aix.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -253,7 +253,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle | ||||
| 	var empty bool | ||||
| 	if len(oob) > 0 { | ||||
| 		// send at least one normal byte | ||||
| 		empty := emptyIovecs(iov) | ||||
| 		empty = emptyIovecs(iov) | ||||
| 		if empty { | ||||
| 			var iova [1]Iovec | ||||
| 			iova[0].Base = &dummy | ||||
|   | ||||
							
								
								
									
										2
									
								
								vendor/golang.org/x/sys/unix/syscall_bsd.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								vendor/golang.org/x/sys/unix/syscall_bsd.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -363,7 +363,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle | ||||
| 	var empty bool | ||||
| 	if len(oob) > 0 { | ||||
| 		// send at least one normal byte | ||||
| 		empty := emptyIovecs(iov) | ||||
| 		empty = emptyIovecs(iov) | ||||
| 		if empty { | ||||
| 			var iova [1]Iovec | ||||
| 			iova[0].Base = &dummy | ||||
|   | ||||
							
								
								
									
										32
									
								
								vendor/golang.org/x/sys/unix/syscall_darwin.1_12.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										32
									
								
								vendor/golang.org/x/sys/unix/syscall_darwin.1_12.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -1,32 +0,0 @@ | ||||
| // Copyright 2019 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build darwin && go1.12 && !go1.13 | ||||
| // +build darwin,go1.12,!go1.13 | ||||
|  | ||||
| package unix | ||||
|  | ||||
| import ( | ||||
| 	"unsafe" | ||||
| ) | ||||
|  | ||||
| const _SYS_GETDIRENTRIES64 = 344 | ||||
|  | ||||
| func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) { | ||||
| 	// To implement this using libSystem we'd need syscall_syscallPtr for | ||||
| 	// fdopendir. However, syscallPtr was only added in Go 1.13, so we fall | ||||
| 	// back to raw syscalls for this func on Go 1.12. | ||||
| 	var p unsafe.Pointer | ||||
| 	if len(buf) > 0 { | ||||
| 		p = unsafe.Pointer(&buf[0]) | ||||
| 	} else { | ||||
| 		p = unsafe.Pointer(&_zero) | ||||
| 	} | ||||
| 	r0, _, e1 := Syscall6(_SYS_GETDIRENTRIES64, uintptr(fd), uintptr(p), uintptr(len(buf)), uintptr(unsafe.Pointer(basep)), 0, 0) | ||||
| 	n = int(r0) | ||||
| 	if e1 != 0 { | ||||
| 		return n, errnoErr(e1) | ||||
| 	} | ||||
| 	return n, nil | ||||
| } | ||||
							
								
								
									
										108
									
								
								vendor/golang.org/x/sys/unix/syscall_darwin.1_13.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										108
									
								
								vendor/golang.org/x/sys/unix/syscall_darwin.1_13.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -1,108 +0,0 @@ | ||||
| // Copyright 2019 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build darwin && go1.13 | ||||
| // +build darwin,go1.13 | ||||
|  | ||||
| package unix | ||||
|  | ||||
| import ( | ||||
| 	"unsafe" | ||||
|  | ||||
| 	"golang.org/x/sys/internal/unsafeheader" | ||||
| ) | ||||
|  | ||||
| //sys	closedir(dir uintptr) (err error) | ||||
| //sys	readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno) | ||||
|  | ||||
| func fdopendir(fd int) (dir uintptr, err error) { | ||||
| 	r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0) | ||||
| 	dir = uintptr(r0) | ||||
| 	if e1 != 0 { | ||||
| 		err = errnoErr(e1) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| var libc_fdopendir_trampoline_addr uintptr | ||||
|  | ||||
| //go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib" | ||||
|  | ||||
| func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) { | ||||
| 	// Simulate Getdirentries using fdopendir/readdir_r/closedir. | ||||
| 	// We store the number of entries to skip in the seek | ||||
| 	// offset of fd. See issue #31368. | ||||
| 	// It's not the full required semantics, but should handle the case | ||||
| 	// of calling Getdirentries or ReadDirent repeatedly. | ||||
| 	// It won't handle assigning the results of lseek to *basep, or handle | ||||
| 	// the directory being edited underfoot. | ||||
| 	skip, err := Seek(fd, 0, 1 /* SEEK_CUR */) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	// We need to duplicate the incoming file descriptor | ||||
| 	// because the caller expects to retain control of it, but | ||||
| 	// fdopendir expects to take control of its argument. | ||||
| 	// Just Dup'ing the file descriptor is not enough, as the | ||||
| 	// result shares underlying state. Use Openat to make a really | ||||
| 	// new file descriptor referring to the same directory. | ||||
| 	fd2, err := Openat(fd, ".", O_RDONLY, 0) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	d, err := fdopendir(fd2) | ||||
| 	if err != nil { | ||||
| 		Close(fd2) | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	defer closedir(d) | ||||
|  | ||||
| 	var cnt int64 | ||||
| 	for { | ||||
| 		var entry Dirent | ||||
| 		var entryp *Dirent | ||||
| 		e := readdir_r(d, &entry, &entryp) | ||||
| 		if e != 0 { | ||||
| 			return n, errnoErr(e) | ||||
| 		} | ||||
| 		if entryp == nil { | ||||
| 			break | ||||
| 		} | ||||
| 		if skip > 0 { | ||||
| 			skip-- | ||||
| 			cnt++ | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		reclen := int(entry.Reclen) | ||||
| 		if reclen > len(buf) { | ||||
| 			// Not enough room. Return for now. | ||||
| 			// The counter will let us know where we should start up again. | ||||
| 			// Note: this strategy for suspending in the middle and | ||||
| 			// restarting is O(n^2) in the length of the directory. Oh well. | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		// Copy entry into return buffer. | ||||
| 		var s []byte | ||||
| 		hdr := (*unsafeheader.Slice)(unsafe.Pointer(&s)) | ||||
| 		hdr.Data = unsafe.Pointer(&entry) | ||||
| 		hdr.Cap = reclen | ||||
| 		hdr.Len = reclen | ||||
| 		copy(buf, s) | ||||
|  | ||||
| 		buf = buf[reclen:] | ||||
| 		n += reclen | ||||
| 		cnt++ | ||||
| 	} | ||||
| 	// Set the seek offset of the input fd to record | ||||
| 	// how many files we've already returned. | ||||
| 	_, err = Seek(fd, cnt, 0 /* SEEK_SET */) | ||||
| 	if err != nil { | ||||
| 		return n, err | ||||
| 	} | ||||
|  | ||||
| 	return n, nil | ||||
| } | ||||
							
								
								
									
										91
									
								
								vendor/golang.org/x/sys/unix/syscall_darwin.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										91
									
								
								vendor/golang.org/x/sys/unix/syscall_darwin.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -19,6 +19,96 @@ import ( | ||||
| 	"unsafe" | ||||
| ) | ||||
|  | ||||
| //sys	closedir(dir uintptr) (err error) | ||||
| //sys	readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno) | ||||
|  | ||||
| func fdopendir(fd int) (dir uintptr, err error) { | ||||
| 	r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0) | ||||
| 	dir = uintptr(r0) | ||||
| 	if e1 != 0 { | ||||
| 		err = errnoErr(e1) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| var libc_fdopendir_trampoline_addr uintptr | ||||
|  | ||||
| //go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib" | ||||
|  | ||||
| func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) { | ||||
| 	// Simulate Getdirentries using fdopendir/readdir_r/closedir. | ||||
| 	// We store the number of entries to skip in the seek | ||||
| 	// offset of fd. See issue #31368. | ||||
| 	// It's not the full required semantics, but should handle the case | ||||
| 	// of calling Getdirentries or ReadDirent repeatedly. | ||||
| 	// It won't handle assigning the results of lseek to *basep, or handle | ||||
| 	// the directory being edited underfoot. | ||||
| 	skip, err := Seek(fd, 0, 1 /* SEEK_CUR */) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	// We need to duplicate the incoming file descriptor | ||||
| 	// because the caller expects to retain control of it, but | ||||
| 	// fdopendir expects to take control of its argument. | ||||
| 	// Just Dup'ing the file descriptor is not enough, as the | ||||
| 	// result shares underlying state. Use Openat to make a really | ||||
| 	// new file descriptor referring to the same directory. | ||||
| 	fd2, err := Openat(fd, ".", O_RDONLY, 0) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	d, err := fdopendir(fd2) | ||||
| 	if err != nil { | ||||
| 		Close(fd2) | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	defer closedir(d) | ||||
|  | ||||
| 	var cnt int64 | ||||
| 	for { | ||||
| 		var entry Dirent | ||||
| 		var entryp *Dirent | ||||
| 		e := readdir_r(d, &entry, &entryp) | ||||
| 		if e != 0 { | ||||
| 			return n, errnoErr(e) | ||||
| 		} | ||||
| 		if entryp == nil { | ||||
| 			break | ||||
| 		} | ||||
| 		if skip > 0 { | ||||
| 			skip-- | ||||
| 			cnt++ | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		reclen := int(entry.Reclen) | ||||
| 		if reclen > len(buf) { | ||||
| 			// Not enough room. Return for now. | ||||
| 			// The counter will let us know where we should start up again. | ||||
| 			// Note: this strategy for suspending in the middle and | ||||
| 			// restarting is O(n^2) in the length of the directory. Oh well. | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		// Copy entry into return buffer. | ||||
| 		s := unsafe.Slice((*byte)(unsafe.Pointer(&entry)), reclen) | ||||
| 		copy(buf, s) | ||||
|  | ||||
| 		buf = buf[reclen:] | ||||
| 		n += reclen | ||||
| 		cnt++ | ||||
| 	} | ||||
| 	// Set the seek offset of the input fd to record | ||||
| 	// how many files we've already returned. | ||||
| 	_, err = Seek(fd, cnt, 0 /* SEEK_SET */) | ||||
| 	if err != nil { | ||||
| 		return n, err | ||||
| 	} | ||||
|  | ||||
| 	return n, nil | ||||
| } | ||||
|  | ||||
| // SockaddrDatalink implements the Sockaddr interface for AF_LINK type sockets. | ||||
| type SockaddrDatalink struct { | ||||
| 	Len    uint8 | ||||
| @@ -140,6 +230,7 @@ func direntNamlen(buf []byte) (uint64, bool) { | ||||
|  | ||||
| func PtraceAttach(pid int) (err error) { return ptrace(PT_ATTACH, pid, 0, 0) } | ||||
| func PtraceDetach(pid int) (err error) { return ptrace(PT_DETACH, pid, 0, 0) } | ||||
| func PtraceDenyAttach() (err error)    { return ptrace(PT_DENY_ATTACH, 0, 0, 0) } | ||||
|  | ||||
| //sysnb	pipe(p *[2]int32) (err error) | ||||
|  | ||||
|   | ||||
							
								
								
									
										1
									
								
								vendor/golang.org/x/sys/unix/syscall_dragonfly.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								vendor/golang.org/x/sys/unix/syscall_dragonfly.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -255,6 +255,7 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e | ||||
| //sys	Chmod(path string, mode uint32) (err error) | ||||
| //sys	Chown(path string, uid int, gid int) (err error) | ||||
| //sys	Chroot(path string) (err error) | ||||
| //sys	ClockGettime(clockid int32, time *Timespec) (err error) | ||||
| //sys	Close(fd int) (err error) | ||||
| //sys	Dup(fd int) (nfd int, err error) | ||||
| //sys	Dup2(from int, to int) (err error) | ||||
|   | ||||
							
								
								
									
										1
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -319,6 +319,7 @@ func PtraceSingleStep(pid int) (err error) { | ||||
| //sys	Chmod(path string, mode uint32) (err error) | ||||
| //sys	Chown(path string, uid int, gid int) (err error) | ||||
| //sys	Chroot(path string) (err error) | ||||
| //sys	ClockGettime(clockid int32, time *Timespec) (err error) | ||||
| //sys	Close(fd int) (err error) | ||||
| //sys	Dup(fd int) (nfd int, err error) | ||||
| //sys	Dup2(from int, to int) (err error) | ||||
|   | ||||
							
								
								
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_386.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_386.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -60,8 +60,13 @@ func PtraceGetFsBase(pid int, fsbase *int64) (err error) { | ||||
| 	return ptrace(PT_GETFSBASE, pid, uintptr(unsafe.Pointer(fsbase)), 0) | ||||
| } | ||||
|  | ||||
| func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint32(countin)} | ||||
| func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{ | ||||
| 		Op:   int32(req), | ||||
| 		Offs: offs, | ||||
| 		Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe. | ||||
| 		Len:  uint32(countin), | ||||
| 	} | ||||
| 	err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0) | ||||
| 	return int(ioDesc.Len), err | ||||
| } | ||||
|   | ||||
							
								
								
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_amd64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_amd64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -60,8 +60,13 @@ func PtraceGetFsBase(pid int, fsbase *int64) (err error) { | ||||
| 	return ptrace(PT_GETFSBASE, pid, uintptr(unsafe.Pointer(fsbase)), 0) | ||||
| } | ||||
|  | ||||
| func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)} | ||||
| func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{ | ||||
| 		Op:   int32(req), | ||||
| 		Offs: offs, | ||||
| 		Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe. | ||||
| 		Len:  uint64(countin), | ||||
| 	} | ||||
| 	err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0) | ||||
| 	return int(ioDesc.Len), err | ||||
| } | ||||
|   | ||||
							
								
								
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_arm.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_arm.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e | ||||
|  | ||||
| func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno) | ||||
|  | ||||
| func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint32(countin)} | ||||
| func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{ | ||||
| 		Op:   int32(req), | ||||
| 		Offs: offs, | ||||
| 		Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe. | ||||
| 		Len:  uint32(countin), | ||||
| 	} | ||||
| 	err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0) | ||||
| 	return int(ioDesc.Len), err | ||||
| } | ||||
|   | ||||
							
								
								
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_arm64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_arm64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e | ||||
|  | ||||
| func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno) | ||||
|  | ||||
| func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)} | ||||
| func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{ | ||||
| 		Op:   int32(req), | ||||
| 		Offs: offs, | ||||
| 		Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe. | ||||
| 		Len:  uint64(countin), | ||||
| 	} | ||||
| 	err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0) | ||||
| 	return int(ioDesc.Len), err | ||||
| } | ||||
|   | ||||
							
								
								
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_riscv64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								vendor/golang.org/x/sys/unix/syscall_freebsd_riscv64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e | ||||
|  | ||||
| func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno) | ||||
|  | ||||
| func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)} | ||||
| func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) { | ||||
| 	ioDesc := PtraceIoDesc{ | ||||
| 		Op:   int32(req), | ||||
| 		Offs: offs, | ||||
| 		Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe. | ||||
| 		Len:  uint64(countin), | ||||
| 	} | ||||
| 	err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0) | ||||
| 	return int(ioDesc.Len), err | ||||
| } | ||||
|   | ||||
							
								
								
									
										22
									
								
								vendor/golang.org/x/sys/unix/syscall_hurd.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								vendor/golang.org/x/sys/unix/syscall_hurd.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| // Copyright 2022 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build hurd | ||||
| // +build hurd | ||||
|  | ||||
| package unix | ||||
|  | ||||
| /* | ||||
| #include <stdint.h> | ||||
| int ioctl(int, unsigned long int, uintptr_t); | ||||
| */ | ||||
| import "C" | ||||
|  | ||||
| func ioctl(fd int, req uint, arg uintptr) (err error) { | ||||
| 	r0, er := C.ioctl(C.int(fd), C.ulong(req), C.uintptr_t(arg)) | ||||
| 	if r0 == -1 && er != nil { | ||||
| 		err = er | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										29
									
								
								vendor/golang.org/x/sys/unix/syscall_hurd_386.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								vendor/golang.org/x/sys/unix/syscall_hurd_386.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| // Copyright 2022 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build 386 && hurd | ||||
| // +build 386,hurd | ||||
|  | ||||
| package unix | ||||
|  | ||||
| const ( | ||||
| 	TIOCGETA = 0x62251713 | ||||
| ) | ||||
|  | ||||
| type Winsize struct { | ||||
| 	Row    uint16 | ||||
| 	Col    uint16 | ||||
| 	Xpixel uint16 | ||||
| 	Ypixel uint16 | ||||
| } | ||||
|  | ||||
| type Termios struct { | ||||
| 	Iflag  uint32 | ||||
| 	Oflag  uint32 | ||||
| 	Cflag  uint32 | ||||
| 	Lflag  uint32 | ||||
| 	Cc     [20]uint8 | ||||
| 	Ispeed int32 | ||||
| 	Ospeed int32 | ||||
| } | ||||
							
								
								
									
										106
									
								
								vendor/golang.org/x/sys/unix/syscall_illumos.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										106
									
								
								vendor/golang.org/x/sys/unix/syscall_illumos.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -10,8 +10,6 @@ | ||||
| package unix | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"runtime" | ||||
| 	"unsafe" | ||||
| ) | ||||
|  | ||||
| @@ -79,107 +77,3 @@ func Accept4(fd int, flags int) (nfd int, sa Sockaddr, err error) { | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| //sys	putmsg(fd int, clptr *strbuf, dataptr *strbuf, flags int) (err error) | ||||
|  | ||||
| func Putmsg(fd int, cl []byte, data []byte, flags int) (err error) { | ||||
| 	var clp, datap *strbuf | ||||
| 	if len(cl) > 0 { | ||||
| 		clp = &strbuf{ | ||||
| 			Len: int32(len(cl)), | ||||
| 			Buf: (*int8)(unsafe.Pointer(&cl[0])), | ||||
| 		} | ||||
| 	} | ||||
| 	if len(data) > 0 { | ||||
| 		datap = &strbuf{ | ||||
| 			Len: int32(len(data)), | ||||
| 			Buf: (*int8)(unsafe.Pointer(&data[0])), | ||||
| 		} | ||||
| 	} | ||||
| 	return putmsg(fd, clp, datap, flags) | ||||
| } | ||||
|  | ||||
| //sys	getmsg(fd int, clptr *strbuf, dataptr *strbuf, flags *int) (err error) | ||||
|  | ||||
| func Getmsg(fd int, cl []byte, data []byte) (retCl []byte, retData []byte, flags int, err error) { | ||||
| 	var clp, datap *strbuf | ||||
| 	if len(cl) > 0 { | ||||
| 		clp = &strbuf{ | ||||
| 			Maxlen: int32(len(cl)), | ||||
| 			Buf:    (*int8)(unsafe.Pointer(&cl[0])), | ||||
| 		} | ||||
| 	} | ||||
| 	if len(data) > 0 { | ||||
| 		datap = &strbuf{ | ||||
| 			Maxlen: int32(len(data)), | ||||
| 			Buf:    (*int8)(unsafe.Pointer(&data[0])), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err = getmsg(fd, clp, datap, &flags); err != nil { | ||||
| 		return nil, nil, 0, err | ||||
| 	} | ||||
|  | ||||
| 	if len(cl) > 0 { | ||||
| 		retCl = cl[:clp.Len] | ||||
| 	} | ||||
| 	if len(data) > 0 { | ||||
| 		retData = data[:datap.Len] | ||||
| 	} | ||||
| 	return retCl, retData, flags, nil | ||||
| } | ||||
|  | ||||
| func IoctlSetIntRetInt(fd int, req uint, arg int) (int, error) { | ||||
| 	return ioctlRet(fd, req, uintptr(arg)) | ||||
| } | ||||
|  | ||||
| func IoctlSetString(fd int, req uint, val string) error { | ||||
| 	bs := make([]byte, len(val)+1) | ||||
| 	copy(bs[:len(bs)-1], val) | ||||
| 	err := ioctl(fd, req, uintptr(unsafe.Pointer(&bs[0]))) | ||||
| 	runtime.KeepAlive(&bs[0]) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Lifreq Helpers | ||||
|  | ||||
| func (l *Lifreq) SetName(name string) error { | ||||
| 	if len(name) >= len(l.Name) { | ||||
| 		return fmt.Errorf("name cannot be more than %d characters", len(l.Name)-1) | ||||
| 	} | ||||
| 	for i := range name { | ||||
| 		l.Name[i] = int8(name[i]) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (l *Lifreq) SetLifruInt(d int) { | ||||
| 	*(*int)(unsafe.Pointer(&l.Lifru[0])) = d | ||||
| } | ||||
|  | ||||
| func (l *Lifreq) GetLifruInt() int { | ||||
| 	return *(*int)(unsafe.Pointer(&l.Lifru[0])) | ||||
| } | ||||
|  | ||||
| func (l *Lifreq) SetLifruUint(d uint) { | ||||
| 	*(*uint)(unsafe.Pointer(&l.Lifru[0])) = d | ||||
| } | ||||
|  | ||||
| func (l *Lifreq) GetLifruUint() uint { | ||||
| 	return *(*uint)(unsafe.Pointer(&l.Lifru[0])) | ||||
| } | ||||
|  | ||||
| func IoctlLifreq(fd int, req uint, l *Lifreq) error { | ||||
| 	return ioctl(fd, req, uintptr(unsafe.Pointer(l))) | ||||
| } | ||||
|  | ||||
| // Strioctl Helpers | ||||
|  | ||||
| func (s *Strioctl) SetInt(i int) { | ||||
| 	s.Len = int32(unsafe.Sizeof(i)) | ||||
| 	s.Dp = (*int8)(unsafe.Pointer(&i)) | ||||
| } | ||||
|  | ||||
| func IoctlSetStrioctlRetInt(fd int, req uint, s *Strioctl) (int, error) { | ||||
| 	return ioctlRet(fd, req, uintptr(unsafe.Pointer(s))) | ||||
| } | ||||
|   | ||||
							
								
								
									
										97
									
								
								vendor/golang.org/x/sys/unix/syscall_linux.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										97
									
								
								vendor/golang.org/x/sys/unix/syscall_linux.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -13,6 +13,7 @@ package unix | ||||
|  | ||||
| import ( | ||||
| 	"encoding/binary" | ||||
| 	"strconv" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
| 	"unsafe" | ||||
| @@ -233,7 +234,7 @@ func Futimesat(dirfd int, path string, tv []Timeval) error { | ||||
| func Futimes(fd int, tv []Timeval) (err error) { | ||||
| 	// Believe it or not, this is the best we can do on Linux | ||||
| 	// (and is what glibc does). | ||||
| 	return Utimes("/proc/self/fd/"+itoa(fd), tv) | ||||
| 	return Utimes("/proc/self/fd/"+strconv.Itoa(fd), tv) | ||||
| } | ||||
|  | ||||
| const ImplementsGetwd = true | ||||
| @@ -1541,7 +1542,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle | ||||
| 	var dummy byte | ||||
| 	var empty bool | ||||
| 	if len(oob) > 0 { | ||||
| 		empty := emptyIovecs(iov) | ||||
| 		empty = emptyIovecs(iov) | ||||
| 		if empty { | ||||
| 			var sockType int | ||||
| 			sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE) | ||||
| @@ -1553,6 +1554,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle | ||||
| 				var iova [1]Iovec | ||||
| 				iova[0].Base = &dummy | ||||
| 				iova[0].SetLen(1) | ||||
| 				iov = iova[:] | ||||
| 			} | ||||
| 		} | ||||
| 		msg.Control = &oob[0] | ||||
| @@ -1798,6 +1800,7 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e | ||||
| //sysnb	Capset(hdr *CapUserHeader, data *CapUserData) (err error) | ||||
| //sys	Chdir(path string) (err error) | ||||
| //sys	Chroot(path string) (err error) | ||||
| //sys	ClockAdjtime(clockid int32, buf *Timex) (state int, err error) | ||||
| //sys	ClockGetres(clockid int32, res *Timespec) (err error) | ||||
| //sys	ClockGettime(clockid int32, time *Timespec) (err error) | ||||
| //sys	ClockNanosleep(clockid int32, flags int, request *Timespec, remain *Timespec) (err error) | ||||
| @@ -1891,17 +1894,28 @@ func PrctlRetInt(option int, arg2 uintptr, arg3 uintptr, arg4 uintptr, arg5 uint | ||||
| 	return int(ret), nil | ||||
| } | ||||
|  | ||||
| // issue 1435. | ||||
| // On linux Setuid and Setgid only affects the current thread, not the process. | ||||
| // This does not match what most callers expect so we must return an error | ||||
| // here rather than letting the caller think that the call succeeded. | ||||
|  | ||||
| func Setuid(uid int) (err error) { | ||||
| 	return EOPNOTSUPP | ||||
| 	return syscall.Setuid(uid) | ||||
| } | ||||
|  | ||||
| func Setgid(uid int) (err error) { | ||||
| 	return EOPNOTSUPP | ||||
| func Setgid(gid int) (err error) { | ||||
| 	return syscall.Setgid(gid) | ||||
| } | ||||
|  | ||||
| func Setreuid(ruid, euid int) (err error) { | ||||
| 	return syscall.Setreuid(ruid, euid) | ||||
| } | ||||
|  | ||||
| func Setregid(rgid, egid int) (err error) { | ||||
| 	return syscall.Setregid(rgid, egid) | ||||
| } | ||||
|  | ||||
| func Setresuid(ruid, euid, suid int) (err error) { | ||||
| 	return syscall.Setresuid(ruid, euid, suid) | ||||
| } | ||||
|  | ||||
| func Setresgid(rgid, egid, sgid int) (err error) { | ||||
| 	return syscall.Setresgid(rgid, egid, sgid) | ||||
| } | ||||
|  | ||||
| // SetfsgidRetGid sets fsgid for current thread and returns previous fsgid set. | ||||
| @@ -1960,36 +1974,46 @@ func Signalfd(fd int, sigmask *Sigset_t, flags int) (newfd int, err error) { | ||||
| //sys	preadv2(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr, flags int) (n int, err error) = SYS_PREADV2 | ||||
| //sys	pwritev2(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr, flags int) (n int, err error) = SYS_PWRITEV2 | ||||
|  | ||||
| func bytes2iovec(bs [][]byte) []Iovec { | ||||
| 	iovecs := make([]Iovec, len(bs)) | ||||
| 	for i, b := range bs { | ||||
| 		iovecs[i].SetLen(len(b)) | ||||
| // minIovec is the size of the small initial allocation used by | ||||
| // Readv, Writev, etc. | ||||
| // | ||||
| // This small allocation gets stack allocated, which lets the | ||||
| // common use case of len(iovs) <= minIovs avoid more expensive | ||||
| // heap allocations. | ||||
| const minIovec = 8 | ||||
|  | ||||
| // appendBytes converts bs to Iovecs and appends them to vecs. | ||||
| func appendBytes(vecs []Iovec, bs [][]byte) []Iovec { | ||||
| 	for _, b := range bs { | ||||
| 		var v Iovec | ||||
| 		v.SetLen(len(b)) | ||||
| 		if len(b) > 0 { | ||||
| 			iovecs[i].Base = &b[0] | ||||
| 			v.Base = &b[0] | ||||
| 		} else { | ||||
| 			iovecs[i].Base = (*byte)(unsafe.Pointer(&_zero)) | ||||
| 			v.Base = (*byte)(unsafe.Pointer(&_zero)) | ||||
| 		} | ||||
| 		vecs = append(vecs, v) | ||||
| 	} | ||||
| 	return iovecs | ||||
| 	return vecs | ||||
| } | ||||
|  | ||||
| // offs2lohi splits offs into its lower and upper unsigned long. On 64-bit | ||||
| // systems, hi will always be 0. On 32-bit systems, offs will be split in half. | ||||
| // preadv/pwritev chose this calling convention so they don't need to add a | ||||
| // padding-register for alignment on ARM. | ||||
| // offs2lohi splits offs into its low and high order bits. | ||||
| func offs2lohi(offs int64) (lo, hi uintptr) { | ||||
| 	return uintptr(offs), uintptr(uint64(offs) >> SizeofLong) | ||||
| 	const longBits = SizeofLong * 8 | ||||
| 	return uintptr(offs), uintptr(uint64(offs) >> (longBits - 1) >> 1) // two shifts to avoid false positive in vet | ||||
| } | ||||
|  | ||||
| func Readv(fd int, iovs [][]byte) (n int, err error) { | ||||
| 	iovecs := bytes2iovec(iovs) | ||||
| 	iovecs := make([]Iovec, 0, minIovec) | ||||
| 	iovecs = appendBytes(iovecs, iovs) | ||||
| 	n, err = readv(fd, iovecs) | ||||
| 	readvRacedetect(iovecs, n, err) | ||||
| 	return n, err | ||||
| } | ||||
|  | ||||
| func Preadv(fd int, iovs [][]byte, offset int64) (n int, err error) { | ||||
| 	iovecs := bytes2iovec(iovs) | ||||
| 	iovecs := make([]Iovec, 0, minIovec) | ||||
| 	iovecs = appendBytes(iovecs, iovs) | ||||
| 	lo, hi := offs2lohi(offset) | ||||
| 	n, err = preadv(fd, iovecs, lo, hi) | ||||
| 	readvRacedetect(iovecs, n, err) | ||||
| @@ -1997,7 +2021,8 @@ func Preadv(fd int, iovs [][]byte, offset int64) (n int, err error) { | ||||
| } | ||||
|  | ||||
| func Preadv2(fd int, iovs [][]byte, offset int64, flags int) (n int, err error) { | ||||
| 	iovecs := bytes2iovec(iovs) | ||||
| 	iovecs := make([]Iovec, 0, minIovec) | ||||
| 	iovecs = appendBytes(iovecs, iovs) | ||||
| 	lo, hi := offs2lohi(offset) | ||||
| 	n, err = preadv2(fd, iovecs, lo, hi, flags) | ||||
| 	readvRacedetect(iovecs, n, err) | ||||
| @@ -2024,7 +2049,8 @@ func readvRacedetect(iovecs []Iovec, n int, err error) { | ||||
| } | ||||
|  | ||||
| func Writev(fd int, iovs [][]byte) (n int, err error) { | ||||
| 	iovecs := bytes2iovec(iovs) | ||||
| 	iovecs := make([]Iovec, 0, minIovec) | ||||
| 	iovecs = appendBytes(iovecs, iovs) | ||||
| 	if raceenabled { | ||||
| 		raceReleaseMerge(unsafe.Pointer(&ioSync)) | ||||
| 	} | ||||
| @@ -2034,7 +2060,8 @@ func Writev(fd int, iovs [][]byte) (n int, err error) { | ||||
| } | ||||
|  | ||||
| func Pwritev(fd int, iovs [][]byte, offset int64) (n int, err error) { | ||||
| 	iovecs := bytes2iovec(iovs) | ||||
| 	iovecs := make([]Iovec, 0, minIovec) | ||||
| 	iovecs = appendBytes(iovecs, iovs) | ||||
| 	if raceenabled { | ||||
| 		raceReleaseMerge(unsafe.Pointer(&ioSync)) | ||||
| 	} | ||||
| @@ -2045,7 +2072,8 @@ func Pwritev(fd int, iovs [][]byte, offset int64) (n int, err error) { | ||||
| } | ||||
|  | ||||
| func Pwritev2(fd int, iovs [][]byte, offset int64, flags int) (n int, err error) { | ||||
| 	iovecs := bytes2iovec(iovs) | ||||
| 	iovecs := make([]Iovec, 0, minIovec) | ||||
| 	iovecs = appendBytes(iovecs, iovs) | ||||
| 	if raceenabled { | ||||
| 		raceReleaseMerge(unsafe.Pointer(&ioSync)) | ||||
| 	} | ||||
| @@ -2240,7 +2268,7 @@ func (fh *FileHandle) Bytes() []byte { | ||||
| 	if n == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return (*[1 << 30]byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&fh.fileHandle.Type)) + 4))[:n:n] | ||||
| 	return unsafe.Slice((*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&fh.fileHandle.Type))+4)), n) | ||||
| } | ||||
|  | ||||
| // NameToHandleAt wraps the name_to_handle_at system call; it obtains | ||||
| @@ -2356,6 +2384,16 @@ func Setitimer(which ItimerWhich, it Itimerval) (Itimerval, error) { | ||||
| 	return prev, nil | ||||
| } | ||||
|  | ||||
| //sysnb	rtSigprocmask(how int, set *Sigset_t, oldset *Sigset_t, sigsetsize uintptr) (err error) = SYS_RT_SIGPROCMASK | ||||
|  | ||||
| func PthreadSigmask(how int, set, oldset *Sigset_t) error { | ||||
| 	if oldset != nil { | ||||
| 		// Explicitly clear in case Sigset_t is larger than _C__NSIG. | ||||
| 		*oldset = Sigset_t{} | ||||
| 	} | ||||
| 	return rtSigprocmask(how, set, oldset, _C__NSIG/8) | ||||
| } | ||||
|  | ||||
| /* | ||||
|  * Unimplemented | ||||
|  */ | ||||
| @@ -2414,7 +2452,6 @@ func Setitimer(which ItimerWhich, it Itimerval) (Itimerval, error) { | ||||
| // RestartSyscall | ||||
| // RtSigaction | ||||
| // RtSigpending | ||||
| // RtSigprocmask | ||||
| // RtSigqueueinfo | ||||
| // RtSigreturn | ||||
| // RtSigsuspend | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_386.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_386.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -41,10 +41,6 @@ func setTimeval(sec, usec int64) Timeval { | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) = SYS_SENDFILE64 | ||||
| //sys	setfsgid(gid int) (prev int, err error) = SYS_SETFSGID32 | ||||
| //sys	setfsuid(uid int) (prev int, err error) = SYS_SETFSUID32 | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) = SYS_SETREGID32 | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) = SYS_SETRESGID32 | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) = SYS_SETRESUID32 | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) = SYS_SETREUID32 | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error) | ||||
| //sys	Stat(path string, stat *Stat_t) (err error) = SYS_STAT64 | ||||
| //sys	SyncFileRange(fd int, off int64, n int64, flags int) (err error) | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_amd64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_amd64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -46,11 +46,7 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setrlimit(resource int, rlim *Rlimit) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) | ||||
|  | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_arm.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_arm.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -62,10 +62,6 @@ func Seek(fd int, offset int64, whence int) (newoffset int64, err error) { | ||||
| //sys	Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error) = SYS__NEWSELECT | ||||
| //sys	setfsgid(gid int) (prev int, err error) = SYS_SETFSGID32 | ||||
| //sys	setfsuid(uid int) (prev int, err error) = SYS_SETFSUID32 | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) = SYS_SETREGID32 | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) = SYS_SETRESGID32 | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) = SYS_SETRESUID32 | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) = SYS_SETREUID32 | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error) | ||||
| //sys	Stat(path string, stat *Stat_t) (err error) = SYS_STAT64 | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_arm64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_arm64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -39,11 +39,7 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	setrlimit(resource int, rlim *Rlimit) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) | ||||
|  | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_loong64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_loong64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -34,10 +34,6 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) | ||||
|  | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_mips64x.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_mips64x.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -37,11 +37,7 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setrlimit(resource int, rlim *Rlimit) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) | ||||
| //sys	Statfs(path string, buf *Statfs_t) (err error) | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_mipsx.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_mipsx.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -32,10 +32,6 @@ func Syscall9(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) = SYS_SENDFILE64 | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error) | ||||
| //sys	SyncFileRange(fd int, off int64, n int64, flags int) (err error) | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_ppc.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_ppc.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -34,10 +34,6 @@ import ( | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) = SYS_SENDFILE64 | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error) | ||||
| //sys	Stat(path string, stat *Stat_t) (err error) = SYS_STAT64 | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_ppc64x.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_ppc64x.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -34,11 +34,7 @@ package unix | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setrlimit(resource int, rlim *Rlimit) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) | ||||
| //sys	Stat(path string, stat *Stat_t) (err error) | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_riscv64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_riscv64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -38,11 +38,7 @@ func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setrlimit(resource int, rlim *Rlimit) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) | ||||
|  | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_s390x.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_s390x.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -34,11 +34,7 @@ import ( | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setrlimit(resource int, rlim *Rlimit) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) | ||||
| //sys	Stat(path string, stat *Stat_t) (err error) | ||||
| //sys	Statfs(path string, buf *Statfs_t) (err error) | ||||
|   | ||||
							
								
								
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_sparc64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/golang.org/x/sys/unix/syscall_linux_sparc64.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -31,11 +31,7 @@ package unix | ||||
| //sys	sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) | ||||
| //sys	setfsgid(gid int) (prev int, err error) | ||||
| //sys	setfsuid(uid int) (prev int, err error) | ||||
| //sysnb	Setregid(rgid int, egid int) (err error) | ||||
| //sysnb	Setresgid(rgid int, egid int, sgid int) (err error) | ||||
| //sysnb	Setresuid(ruid int, euid int, suid int) (err error) | ||||
| //sysnb	Setrlimit(resource int, rlim *Rlimit) (err error) | ||||
| //sysnb	Setreuid(ruid int, euid int) (err error) | ||||
| //sys	Shutdown(fd int, how int) (err error) | ||||
| //sys	Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) | ||||
| //sys	Stat(path string, stat *Stat_t) (err error) | ||||
|   | ||||
							
								
								
									
										15
									
								
								vendor/golang.org/x/sys/unix/syscall_netbsd.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								vendor/golang.org/x/sys/unix/syscall_netbsd.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -110,6 +110,20 @@ func direntNamlen(buf []byte) (uint64, bool) { | ||||
| 	return readInt(buf, unsafe.Offsetof(Dirent{}.Namlen), unsafe.Sizeof(Dirent{}.Namlen)) | ||||
| } | ||||
|  | ||||
| func SysctlUvmexp(name string) (*Uvmexp, error) { | ||||
| 	mib, err := sysctlmib(name) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	n := uintptr(SizeofUvmexp) | ||||
| 	var u Uvmexp | ||||
| 	if err := sysctl(mib, (*byte)(unsafe.Pointer(&u)), &n, nil, 0); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &u, nil | ||||
| } | ||||
|  | ||||
| func Pipe(p []int) (err error) { | ||||
| 	return Pipe2(p, 0) | ||||
| } | ||||
| @@ -245,6 +259,7 @@ func Statvfs(path string, buf *Statvfs_t) (err error) { | ||||
| //sys	Chmod(path string, mode uint32) (err error) | ||||
| //sys	Chown(path string, uid int, gid int) (err error) | ||||
| //sys	Chroot(path string) (err error) | ||||
| //sys	ClockGettime(clockid int32, time *Timespec) (err error) | ||||
| //sys	Close(fd int) (err error) | ||||
| //sys	Dup(fd int) (nfd int, err error) | ||||
| //sys	Dup2(from int, to int) (err error) | ||||
|   | ||||
							
								
								
									
										1
									
								
								vendor/golang.org/x/sys/unix/syscall_openbsd.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								vendor/golang.org/x/sys/unix/syscall_openbsd.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -220,6 +220,7 @@ func Uname(uname *Utsname) error { | ||||
| //sys	Chmod(path string, mode uint32) (err error) | ||||
| //sys	Chown(path string, uid int, gid int) (err error) | ||||
| //sys	Chroot(path string) (err error) | ||||
| //sys	ClockGettime(clockid int32, time *Timespec) (err error) | ||||
| //sys	Close(fd int) (err error) | ||||
| //sys	Dup(fd int) (nfd int, err error) | ||||
| //sys	Dup2(from int, to int) (err error) | ||||
|   | ||||
							
								
								
									
										27
									
								
								vendor/golang.org/x/sys/unix/syscall_openbsd_libc.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								vendor/golang.org/x/sys/unix/syscall_openbsd_libc.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| // Copyright 2022 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build openbsd | ||||
| // +build openbsd | ||||
|  | ||||
| package unix | ||||
|  | ||||
| import _ "unsafe" | ||||
|  | ||||
| // Implemented in the runtime package (runtime/sys_openbsd3.go) | ||||
| func syscall_syscall(fn, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) | ||||
| func syscall_syscall6(fn, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err Errno) | ||||
| func syscall_syscall10(fn, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 uintptr) (r1, r2 uintptr, err Errno) | ||||
| func syscall_rawSyscall(fn, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) | ||||
| func syscall_rawSyscall6(fn, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err Errno) | ||||
|  | ||||
| //go:linkname syscall_syscall syscall.syscall | ||||
| //go:linkname syscall_syscall6 syscall.syscall6 | ||||
| //go:linkname syscall_syscall10 syscall.syscall10 | ||||
| //go:linkname syscall_rawSyscall syscall.rawSyscall | ||||
| //go:linkname syscall_rawSyscall6 syscall.rawSyscall6 | ||||
|  | ||||
| func syscall_syscall9(fn, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err Errno) { | ||||
| 	return syscall_syscall10(fn, a1, a2, a3, a4, a5, a6, a7, a8, a9, 0) | ||||
| } | ||||
							
								
								
									
										42
									
								
								vendor/golang.org/x/sys/unix/syscall_openbsd_ppc64.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								vendor/golang.org/x/sys/unix/syscall_openbsd_ppc64.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| // Copyright 2019 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build ppc64 && openbsd | ||||
| // +build ppc64,openbsd | ||||
|  | ||||
| package unix | ||||
|  | ||||
| func setTimespec(sec, nsec int64) Timespec { | ||||
| 	return Timespec{Sec: sec, Nsec: nsec} | ||||
| } | ||||
|  | ||||
| func setTimeval(sec, usec int64) Timeval { | ||||
| 	return Timeval{Sec: sec, Usec: usec} | ||||
| } | ||||
|  | ||||
| func SetKevent(k *Kevent_t, fd, mode, flags int) { | ||||
| 	k.Ident = uint64(fd) | ||||
| 	k.Filter = int16(mode) | ||||
| 	k.Flags = uint16(flags) | ||||
| } | ||||
|  | ||||
| func (iov *Iovec) SetLen(length int) { | ||||
| 	iov.Len = uint64(length) | ||||
| } | ||||
|  | ||||
| func (msghdr *Msghdr) SetControllen(length int) { | ||||
| 	msghdr.Controllen = uint32(length) | ||||
| } | ||||
|  | ||||
| func (msghdr *Msghdr) SetIovlen(length int) { | ||||
| 	msghdr.Iovlen = uint32(length) | ||||
| } | ||||
|  | ||||
| func (cmsg *Cmsghdr) SetLen(length int) { | ||||
| 	cmsg.Len = uint32(length) | ||||
| } | ||||
|  | ||||
| // SYS___SYSCTL is used by syscall_bsd.go for all BSDs, but in modern versions | ||||
| // of openbsd/ppc64 the syscall is called sysctl instead of __sysctl. | ||||
| const SYS___SYSCTL = SYS_SYSCTL | ||||
							
								
								
									
										42
									
								
								vendor/golang.org/x/sys/unix/syscall_openbsd_riscv64.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								vendor/golang.org/x/sys/unix/syscall_openbsd_riscv64.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| // Copyright 2019 The Go Authors. All rights reserved. | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build riscv64 && openbsd | ||||
| // +build riscv64,openbsd | ||||
|  | ||||
| package unix | ||||
|  | ||||
| func setTimespec(sec, nsec int64) Timespec { | ||||
| 	return Timespec{Sec: sec, Nsec: nsec} | ||||
| } | ||||
|  | ||||
| func setTimeval(sec, usec int64) Timeval { | ||||
| 	return Timeval{Sec: sec, Usec: usec} | ||||
| } | ||||
|  | ||||
| func SetKevent(k *Kevent_t, fd, mode, flags int) { | ||||
| 	k.Ident = uint64(fd) | ||||
| 	k.Filter = int16(mode) | ||||
| 	k.Flags = uint16(flags) | ||||
| } | ||||
|  | ||||
| func (iov *Iovec) SetLen(length int) { | ||||
| 	iov.Len = uint64(length) | ||||
| } | ||||
|  | ||||
| func (msghdr *Msghdr) SetControllen(length int) { | ||||
| 	msghdr.Controllen = uint32(length) | ||||
| } | ||||
|  | ||||
| func (msghdr *Msghdr) SetIovlen(length int) { | ||||
| 	msghdr.Iovlen = uint32(length) | ||||
| } | ||||
|  | ||||
| func (cmsg *Cmsghdr) SetLen(length int) { | ||||
| 	cmsg.Len = uint32(length) | ||||
| } | ||||
|  | ||||
| // SYS___SYSCTL is used by syscall_bsd.go for all BSDs, but in modern versions | ||||
| // of openbsd/riscv64 the syscall is called sysctl instead of __sysctl. | ||||
| const SYS___SYSCTL = SYS_SYSCTL | ||||
							
								
								
									
										216
									
								
								vendor/golang.org/x/sys/unix/syscall_solaris.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										216
									
								
								vendor/golang.org/x/sys/unix/syscall_solaris.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -590,6 +590,7 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e | ||||
| //sys	Chmod(path string, mode uint32) (err error) | ||||
| //sys	Chown(path string, uid int, gid int) (err error) | ||||
| //sys	Chroot(path string) (err error) | ||||
| //sys	ClockGettime(clockid int32, time *Timespec) (err error) | ||||
| //sys	Close(fd int) (err error) | ||||
| //sys	Creat(path string, mode uint32) (fd int, err error) | ||||
| //sys	Dup(fd int) (nfd int, err error) | ||||
| @@ -750,8 +751,8 @@ type EventPort struct { | ||||
| 	// we should handle things gracefully. To do so, we need to keep an extra | ||||
| 	// reference to the cookie around until the event is processed | ||||
| 	// thus the otherwise seemingly extraneous "cookies" map | ||||
| 	// The key of this map is a pointer to the corresponding &fCookie.cookie | ||||
| 	cookies map[*interface{}]*fileObjCookie | ||||
| 	// The key of this map is a pointer to the corresponding fCookie | ||||
| 	cookies map[*fileObjCookie]struct{} | ||||
| } | ||||
|  | ||||
| // PortEvent is an abstraction of the port_event C struct. | ||||
| @@ -778,7 +779,7 @@ func NewEventPort() (*EventPort, error) { | ||||
| 		port:    port, | ||||
| 		fds:     make(map[uintptr]*fileObjCookie), | ||||
| 		paths:   make(map[string]*fileObjCookie), | ||||
| 		cookies: make(map[*interface{}]*fileObjCookie), | ||||
| 		cookies: make(map[*fileObjCookie]struct{}), | ||||
| 	} | ||||
| 	return e, nil | ||||
| } | ||||
| @@ -799,6 +800,7 @@ func (e *EventPort) Close() error { | ||||
| 	} | ||||
| 	e.fds = nil | ||||
| 	e.paths = nil | ||||
| 	e.cookies = nil | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -826,17 +828,16 @@ func (e *EventPort) AssociatePath(path string, stat os.FileInfo, events int, coo | ||||
| 	if _, found := e.paths[path]; found { | ||||
| 		return fmt.Errorf("%v is already associated with this Event Port", path) | ||||
| 	} | ||||
| 	fobj, err := createFileObj(path, stat) | ||||
| 	fCookie, err := createFileObjCookie(path, stat, cookie) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	fCookie := &fileObjCookie{fobj, cookie} | ||||
| 	_, err = port_associate(e.port, PORT_SOURCE_FILE, uintptr(unsafe.Pointer(fobj)), events, (*byte)(unsafe.Pointer(&fCookie.cookie))) | ||||
| 	_, err = port_associate(e.port, PORT_SOURCE_FILE, uintptr(unsafe.Pointer(fCookie.fobj)), events, (*byte)(unsafe.Pointer(fCookie))) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	e.paths[path] = fCookie | ||||
| 	e.cookies[&fCookie.cookie] = fCookie | ||||
| 	e.cookies[fCookie] = struct{}{} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -858,7 +859,7 @@ func (e *EventPort) DissociatePath(path string) error { | ||||
| 	if err == nil { | ||||
| 		// dissociate was successful, safe to delete the cookie | ||||
| 		fCookie := e.paths[path] | ||||
| 		delete(e.cookies, &fCookie.cookie) | ||||
| 		delete(e.cookies, fCookie) | ||||
| 	} | ||||
| 	delete(e.paths, path) | ||||
| 	return err | ||||
| @@ -871,13 +872,16 @@ func (e *EventPort) AssociateFd(fd uintptr, events int, cookie interface{}) erro | ||||
| 	if _, found := e.fds[fd]; found { | ||||
| 		return fmt.Errorf("%v is already associated with this Event Port", fd) | ||||
| 	} | ||||
| 	fCookie := &fileObjCookie{nil, cookie} | ||||
| 	_, err := port_associate(e.port, PORT_SOURCE_FD, fd, events, (*byte)(unsafe.Pointer(&fCookie.cookie))) | ||||
| 	fCookie, err := createFileObjCookie("", nil, cookie) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	_, err = port_associate(e.port, PORT_SOURCE_FD, fd, events, (*byte)(unsafe.Pointer(fCookie))) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	e.fds[fd] = fCookie | ||||
| 	e.cookies[&fCookie.cookie] = fCookie | ||||
| 	e.cookies[fCookie] = struct{}{} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -896,27 +900,31 @@ func (e *EventPort) DissociateFd(fd uintptr) error { | ||||
| 	if err == nil { | ||||
| 		// dissociate was successful, safe to delete the cookie | ||||
| 		fCookie := e.fds[fd] | ||||
| 		delete(e.cookies, &fCookie.cookie) | ||||
| 		delete(e.cookies, fCookie) | ||||
| 	} | ||||
| 	delete(e.fds, fd) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func createFileObj(name string, stat os.FileInfo) (*fileObj, error) { | ||||
| 	fobj := new(fileObj) | ||||
| 	bs, err := ByteSliceFromString(name) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| func createFileObjCookie(name string, stat os.FileInfo, cookie interface{}) (*fileObjCookie, error) { | ||||
| 	fCookie := new(fileObjCookie) | ||||
| 	fCookie.cookie = cookie | ||||
| 	if name != "" && stat != nil { | ||||
| 		fCookie.fobj = new(fileObj) | ||||
| 		bs, err := ByteSliceFromString(name) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		fCookie.fobj.Name = (*int8)(unsafe.Pointer(&bs[0])) | ||||
| 		s := stat.Sys().(*syscall.Stat_t) | ||||
| 		fCookie.fobj.Atim.Sec = s.Atim.Sec | ||||
| 		fCookie.fobj.Atim.Nsec = s.Atim.Nsec | ||||
| 		fCookie.fobj.Mtim.Sec = s.Mtim.Sec | ||||
| 		fCookie.fobj.Mtim.Nsec = s.Mtim.Nsec | ||||
| 		fCookie.fobj.Ctim.Sec = s.Ctim.Sec | ||||
| 		fCookie.fobj.Ctim.Nsec = s.Ctim.Nsec | ||||
| 	} | ||||
| 	fobj.Name = (*int8)(unsafe.Pointer(&bs[0])) | ||||
| 	s := stat.Sys().(*syscall.Stat_t) | ||||
| 	fobj.Atim.Sec = s.Atim.Sec | ||||
| 	fobj.Atim.Nsec = s.Atim.Nsec | ||||
| 	fobj.Mtim.Sec = s.Mtim.Sec | ||||
| 	fobj.Mtim.Nsec = s.Mtim.Nsec | ||||
| 	fobj.Ctim.Sec = s.Ctim.Sec | ||||
| 	fobj.Ctim.Nsec = s.Ctim.Nsec | ||||
| 	return fobj, nil | ||||
| 	return fCookie, nil | ||||
| } | ||||
|  | ||||
| // GetOne wraps port_get(3c) and returns a single PortEvent. | ||||
| @@ -929,44 +937,50 @@ func (e *EventPort) GetOne(t *Timespec) (*PortEvent, error) { | ||||
| 	p := new(PortEvent) | ||||
| 	e.mu.Lock() | ||||
| 	defer e.mu.Unlock() | ||||
| 	e.peIntToExt(pe, p) | ||||
| 	err = e.peIntToExt(pe, p) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return p, nil | ||||
| } | ||||
|  | ||||
| // peIntToExt converts a cgo portEvent struct into the friendlier PortEvent | ||||
| // NOTE: Always call this function while holding the e.mu mutex | ||||
| func (e *EventPort) peIntToExt(peInt *portEvent, peExt *PortEvent) { | ||||
| func (e *EventPort) peIntToExt(peInt *portEvent, peExt *PortEvent) error { | ||||
| 	if e.cookies == nil { | ||||
| 		return fmt.Errorf("this EventPort is already closed") | ||||
| 	} | ||||
| 	peExt.Events = peInt.Events | ||||
| 	peExt.Source = peInt.Source | ||||
| 	cookie := (*interface{})(unsafe.Pointer(peInt.User)) | ||||
| 	peExt.Cookie = *cookie | ||||
| 	fCookie := (*fileObjCookie)(unsafe.Pointer(peInt.User)) | ||||
| 	_, found := e.cookies[fCookie] | ||||
|  | ||||
| 	if !found { | ||||
| 		panic("unexpected event port address; may be due to kernel bug; see https://go.dev/issue/54254") | ||||
| 	} | ||||
| 	peExt.Cookie = fCookie.cookie | ||||
| 	delete(e.cookies, fCookie) | ||||
|  | ||||
| 	switch peInt.Source { | ||||
| 	case PORT_SOURCE_FD: | ||||
| 		delete(e.cookies, cookie) | ||||
| 		peExt.Fd = uintptr(peInt.Object) | ||||
| 		// Only remove the fds entry if it exists and this cookie matches | ||||
| 		if fobj, ok := e.fds[peExt.Fd]; ok { | ||||
| 			if &fobj.cookie == cookie { | ||||
| 			if fobj == fCookie { | ||||
| 				delete(e.fds, peExt.Fd) | ||||
| 			} | ||||
| 		} | ||||
| 	case PORT_SOURCE_FILE: | ||||
| 		if fCookie, ok := e.cookies[cookie]; ok && uintptr(unsafe.Pointer(fCookie.fobj)) == uintptr(peInt.Object) { | ||||
| 			// Use our stashed reference rather than using unsafe on what we got back | ||||
| 			// the unsafe version would be (*fileObj)(unsafe.Pointer(uintptr(peInt.Object))) | ||||
| 			peExt.fobj = fCookie.fobj | ||||
| 		} else { | ||||
| 			panic("mismanaged memory") | ||||
| 		} | ||||
| 		delete(e.cookies, cookie) | ||||
| 		peExt.fobj = fCookie.fobj | ||||
| 		peExt.Path = BytePtrToString((*byte)(unsafe.Pointer(peExt.fobj.Name))) | ||||
| 		// Only remove the paths entry if it exists and this cookie matches | ||||
| 		if fobj, ok := e.paths[peExt.Path]; ok { | ||||
| 			if &fobj.cookie == cookie { | ||||
| 			if fobj == fCookie { | ||||
| 				delete(e.paths, peExt.Path) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Pending wraps port_getn(3c) and returns how many events are pending. | ||||
| @@ -990,7 +1004,7 @@ func (e *EventPort) Get(s []PortEvent, min int, timeout *Timespec) (int, error) | ||||
| 	got := uint32(min) | ||||
| 	max := uint32(len(s)) | ||||
| 	var err error | ||||
| 	ps := make([]portEvent, max, max) | ||||
| 	ps := make([]portEvent, max) | ||||
| 	_, err = port_getn(e.port, &ps[0], max, &got, timeout) | ||||
| 	// got will be trustworthy with ETIME, but not any other error. | ||||
| 	if err != nil && err != ETIME { | ||||
| @@ -998,8 +1012,122 @@ func (e *EventPort) Get(s []PortEvent, min int, timeout *Timespec) (int, error) | ||||
| 	} | ||||
| 	e.mu.Lock() | ||||
| 	defer e.mu.Unlock() | ||||
| 	valid := 0 | ||||
| 	for i := 0; i < int(got); i++ { | ||||
| 		e.peIntToExt(&ps[i], &s[i]) | ||||
| 		err2 := e.peIntToExt(&ps[i], &s[i]) | ||||
| 		if err2 != nil { | ||||
| 			if valid == 0 && err == nil { | ||||
| 				// If err2 is the only error and there are no valid events | ||||
| 				// to return, return it to the caller. | ||||
| 				err = err2 | ||||
| 			} | ||||
| 			break | ||||
| 		} | ||||
| 		valid = i + 1 | ||||
| 	} | ||||
| 	return int(got), err | ||||
| 	return valid, err | ||||
| } | ||||
|  | ||||
| //sys	putmsg(fd int, clptr *strbuf, dataptr *strbuf, flags int) (err error) | ||||
|  | ||||
| func Putmsg(fd int, cl []byte, data []byte, flags int) (err error) { | ||||
| 	var clp, datap *strbuf | ||||
| 	if len(cl) > 0 { | ||||
| 		clp = &strbuf{ | ||||
| 			Len: int32(len(cl)), | ||||
| 			Buf: (*int8)(unsafe.Pointer(&cl[0])), | ||||
| 		} | ||||
| 	} | ||||
| 	if len(data) > 0 { | ||||
| 		datap = &strbuf{ | ||||
| 			Len: int32(len(data)), | ||||
| 			Buf: (*int8)(unsafe.Pointer(&data[0])), | ||||
| 		} | ||||
| 	} | ||||
| 	return putmsg(fd, clp, datap, flags) | ||||
| } | ||||
|  | ||||
| //sys	getmsg(fd int, clptr *strbuf, dataptr *strbuf, flags *int) (err error) | ||||
|  | ||||
| func Getmsg(fd int, cl []byte, data []byte) (retCl []byte, retData []byte, flags int, err error) { | ||||
| 	var clp, datap *strbuf | ||||
| 	if len(cl) > 0 { | ||||
| 		clp = &strbuf{ | ||||
| 			Maxlen: int32(len(cl)), | ||||
| 			Buf:    (*int8)(unsafe.Pointer(&cl[0])), | ||||
| 		} | ||||
| 	} | ||||
| 	if len(data) > 0 { | ||||
| 		datap = &strbuf{ | ||||
| 			Maxlen: int32(len(data)), | ||||
| 			Buf:    (*int8)(unsafe.Pointer(&data[0])), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err = getmsg(fd, clp, datap, &flags); err != nil { | ||||
| 		return nil, nil, 0, err | ||||
| 	} | ||||
|  | ||||
| 	if len(cl) > 0 { | ||||
| 		retCl = cl[:clp.Len] | ||||
| 	} | ||||
| 	if len(data) > 0 { | ||||
| 		retData = data[:datap.Len] | ||||
| 	} | ||||
| 	return retCl, retData, flags, nil | ||||
| } | ||||
|  | ||||
| func IoctlSetIntRetInt(fd int, req uint, arg int) (int, error) { | ||||
| 	return ioctlRet(fd, req, uintptr(arg)) | ||||
| } | ||||
|  | ||||
| func IoctlSetString(fd int, req uint, val string) error { | ||||
| 	bs := make([]byte, len(val)+1) | ||||
| 	copy(bs[:len(bs)-1], val) | ||||
| 	err := ioctl(fd, req, uintptr(unsafe.Pointer(&bs[0]))) | ||||
| 	runtime.KeepAlive(&bs[0]) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Lifreq Helpers | ||||
|  | ||||
| func (l *Lifreq) SetName(name string) error { | ||||
| 	if len(name) >= len(l.Name) { | ||||
| 		return fmt.Errorf("name cannot be more than %d characters", len(l.Name)-1) | ||||
| 	} | ||||
| 	for i := range name { | ||||
| 		l.Name[i] = int8(name[i]) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (l *Lifreq) SetLifruInt(d int) { | ||||
| 	*(*int)(unsafe.Pointer(&l.Lifru[0])) = d | ||||
| } | ||||
|  | ||||
| func (l *Lifreq) GetLifruInt() int { | ||||
| 	return *(*int)(unsafe.Pointer(&l.Lifru[0])) | ||||
| } | ||||
|  | ||||
| func (l *Lifreq) SetLifruUint(d uint) { | ||||
| 	*(*uint)(unsafe.Pointer(&l.Lifru[0])) = d | ||||
| } | ||||
|  | ||||
| func (l *Lifreq) GetLifruUint() uint { | ||||
| 	return *(*uint)(unsafe.Pointer(&l.Lifru[0])) | ||||
| } | ||||
|  | ||||
| func IoctlLifreq(fd int, req uint, l *Lifreq) error { | ||||
| 	return ioctl(fd, req, uintptr(unsafe.Pointer(l))) | ||||
| } | ||||
|  | ||||
| // Strioctl Helpers | ||||
|  | ||||
| func (s *Strioctl) SetInt(i int) { | ||||
| 	s.Len = int32(unsafe.Sizeof(i)) | ||||
| 	s.Dp = (*int8)(unsafe.Pointer(&i)) | ||||
| } | ||||
|  | ||||
| func IoctlSetStrioctlRetInt(fd int, req uint, s *Strioctl) (int, error) { | ||||
| 	return ioctlRet(fd, req, uintptr(unsafe.Pointer(s))) | ||||
| } | ||||
|   | ||||
							
								
								
									
										77
									
								
								vendor/golang.org/x/sys/unix/syscall_unix.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										77
									
								
								vendor/golang.org/x/sys/unix/syscall_unix.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -13,8 +13,6 @@ import ( | ||||
| 	"sync" | ||||
| 	"syscall" | ||||
| 	"unsafe" | ||||
|  | ||||
| 	"golang.org/x/sys/internal/unsafeheader" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| @@ -117,11 +115,7 @@ func (m *mmapper) Mmap(fd int, offset int64, length int, prot int, flags int) (d | ||||
| 	} | ||||
|  | ||||
| 	// Use unsafe to convert addr into a []byte. | ||||
| 	var b []byte | ||||
| 	hdr := (*unsafeheader.Slice)(unsafe.Pointer(&b)) | ||||
| 	hdr.Data = unsafe.Pointer(addr) | ||||
| 	hdr.Cap = length | ||||
| 	hdr.Len = length | ||||
| 	b := unsafe.Slice((*byte)(unsafe.Pointer(addr)), length) | ||||
|  | ||||
| 	// Register mapping in m and return it. | ||||
| 	p := &b[cap(b)-1] | ||||
| @@ -337,6 +331,19 @@ func Recvfrom(fd int, p []byte, flags int) (n int, from Sockaddr, err error) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // Recvmsg receives a message from a socket using the recvmsg system call. The | ||||
| // received non-control data will be written to p, and any "out of band" | ||||
| // control data will be written to oob. The flags are passed to recvmsg. | ||||
| // | ||||
| // The results are: | ||||
| //   - n is the number of non-control data bytes read into p | ||||
| //   - oobn is the number of control data bytes read into oob; this may be interpreted using [ParseSocketControlMessage] | ||||
| //   - recvflags is flags returned by recvmsg | ||||
| //   - from is the address of the sender | ||||
| // | ||||
| // If the underlying socket type is not SOCK_DGRAM, a received message | ||||
| // containing oob data and a single '\0' of non-control data is treated as if | ||||
| // the message contained only control data, i.e. n will be zero on return. | ||||
| func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) { | ||||
| 	var iov [1]Iovec | ||||
| 	if len(p) > 0 { | ||||
| @@ -352,13 +359,9 @@ func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // RecvmsgBuffers receives a message from a socket using the recvmsg | ||||
| // system call. The flags are passed to recvmsg. Any non-control data | ||||
| // read is scattered into the buffers slices. The results are: | ||||
| //   - n is the number of non-control data read into bufs | ||||
| //   - oobn is the number of control data read into oob; this may be interpreted using [ParseSocketControlMessage] | ||||
| //   - recvflags is flags returned by recvmsg | ||||
| //   - from is the address of the sender | ||||
| // RecvmsgBuffers receives a message from a socket using the recvmsg system | ||||
| // call. This function is equivalent to Recvmsg, but non-control data read is | ||||
| // scattered into the buffers slices. | ||||
| func RecvmsgBuffers(fd int, buffers [][]byte, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) { | ||||
| 	iov := make([]Iovec, len(buffers)) | ||||
| 	for i := range buffers { | ||||
| @@ -377,11 +380,38 @@ func RecvmsgBuffers(fd int, buffers [][]byte, oob []byte, flags int) (n, oobn in | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // Sendmsg sends a message on a socket to an address using the sendmsg system | ||||
| // call. This function is equivalent to SendmsgN, but does not return the | ||||
| // number of bytes actually sent. | ||||
| func Sendmsg(fd int, p, oob []byte, to Sockaddr, flags int) (err error) { | ||||
| 	_, err = SendmsgN(fd, p, oob, to, flags) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // SendmsgN sends a message on a socket to an address using the sendmsg system | ||||
| // call. p contains the non-control data to send, and oob contains the "out of | ||||
| // band" control data. The flags are passed to sendmsg. The number of | ||||
| // non-control bytes actually written to the socket is returned. | ||||
| // | ||||
| // Some socket types do not support sending control data without accompanying | ||||
| // non-control data. If p is empty, and oob contains control data, and the | ||||
| // underlying socket type is not SOCK_DGRAM, p will be treated as containing a | ||||
| // single '\0' and the return value will indicate zero bytes sent. | ||||
| // | ||||
| // The Go function Recvmsg, if called with an empty p and a non-empty oob, | ||||
| // will read and ignore this additional '\0'.  If the message is received by | ||||
| // code that does not use Recvmsg, or that does not use Go at all, that code | ||||
| // will need to be written to expect and ignore the additional '\0'. | ||||
| // | ||||
| // If you need to send non-empty oob with p actually empty, and if the | ||||
| // underlying socket type supports it, you can do so via a raw system call as | ||||
| // follows: | ||||
| // | ||||
| //	msg := &unix.Msghdr{ | ||||
| //	    Control: &oob[0], | ||||
| //	} | ||||
| //	msg.SetControllen(len(oob)) | ||||
| //	n, _, errno := unix.Syscall(unix.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), flags) | ||||
| func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) { | ||||
| 	var iov [1]Iovec | ||||
| 	if len(p) > 0 { | ||||
| @@ -400,9 +430,8 @@ func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) | ||||
| } | ||||
|  | ||||
| // SendmsgBuffers sends a message on a socket to an address using the sendmsg | ||||
| // system call. The flags are passed to sendmsg. Any non-control data written | ||||
| // is gathered from buffers. The function returns the number of bytes written | ||||
| // to the socket. | ||||
| // system call. This function is equivalent to SendmsgN, but the non-control | ||||
| // data is gathered from buffers. | ||||
| func SendmsgBuffers(fd int, buffers [][]byte, oob []byte, to Sockaddr, flags int) (n int, err error) { | ||||
| 	iov := make([]Iovec, len(buffers)) | ||||
| 	for i := range buffers { | ||||
| @@ -429,11 +458,15 @@ func Send(s int, buf []byte, flags int) (err error) { | ||||
| } | ||||
|  | ||||
| func Sendto(fd int, p []byte, flags int, to Sockaddr) (err error) { | ||||
| 	ptr, n, err := to.sockaddr() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	var ptr unsafe.Pointer | ||||
| 	var salen _Socklen | ||||
| 	if to != nil { | ||||
| 		ptr, salen, err = to.sockaddr() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return sendto(fd, p, flags, ptr, n) | ||||
| 	return sendto(fd, p, flags, ptr, salen) | ||||
| } | ||||
|  | ||||
| func SetsockoptByte(fd, level, opt int, value byte) (err error) { | ||||
| @@ -545,7 +578,7 @@ func Lutimes(path string, tv []Timeval) error { | ||||
| 	return UtimesNanoAt(AT_FDCWD, path, ts, AT_SYMLINK_NOFOLLOW) | ||||
| } | ||||
|  | ||||
| // emptyIovec reports whether there are no bytes in the slice of Iovec. | ||||
| // emptyIovecs reports whether there are no bytes in the slice of Iovec. | ||||
| func emptyIovecs(iov []Iovec) bool { | ||||
| 	for i := range iov { | ||||
| 		if iov[i].Len > 0 { | ||||
|   | ||||
							
								
								
									
										6
									
								
								vendor/golang.org/x/sys/unix/syscall_unix_gc.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								vendor/golang.org/x/sys/unix/syscall_unix_gc.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -2,11 +2,9 @@ | ||||
| // Use of this source code is governed by a BSD-style | ||||
| // license that can be found in the LICENSE file. | ||||
|  | ||||
| //go:build (darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris) && gc && !ppc64le && !ppc64 | ||||
| // +build darwin dragonfly freebsd linux netbsd openbsd solaris | ||||
| //go:build (darwin || dragonfly || freebsd || (linux && !ppc64 && !ppc64le) || netbsd || openbsd || solaris) && gc | ||||
| // +build darwin dragonfly freebsd linux,!ppc64,!ppc64le netbsd openbsd solaris | ||||
| // +build gc | ||||
| // +build !ppc64le | ||||
| // +build !ppc64 | ||||
|  | ||||
| package unix | ||||
|  | ||||
|   | ||||
							
								
								
									
										173
									
								
								vendor/golang.org/x/sys/unix/syscall_zos_s390x.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										173
									
								
								vendor/golang.org/x/sys/unix/syscall_zos_s390x.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -9,8 +9,10 @@ package unix | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"runtime" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"syscall" | ||||
| 	"unsafe" | ||||
| @@ -55,7 +57,13 @@ func (d *Dirent) NameString() string { | ||||
| 	if d == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return string(d.Name[:d.Namlen]) | ||||
| 	s := string(d.Name[:]) | ||||
| 	idx := strings.IndexByte(s, 0) | ||||
| 	if idx == -1 { | ||||
| 		return s | ||||
| 	} else { | ||||
| 		return s[:idx] | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (sa *SockaddrInet4) sockaddr() (unsafe.Pointer, _Socklen, error) { | ||||
| @@ -1230,6 +1238,14 @@ func Readdir(dir uintptr) (*Dirent, error) { | ||||
| 	return &ent, err | ||||
| } | ||||
|  | ||||
| func readdir_r(dirp uintptr, entry *direntLE, result **direntLE) (err error) { | ||||
| 	r0, _, e1 := syscall_syscall(SYS___READDIR_R_A, dirp, uintptr(unsafe.Pointer(entry)), uintptr(unsafe.Pointer(result))) | ||||
| 	if int64(r0) == -1 { | ||||
| 		err = errnoErr(Errno(e1)) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func Closedir(dir uintptr) error { | ||||
| 	_, _, e := syscall_syscall(SYS_CLOSEDIR, dir, 0, 0) | ||||
| 	if e != 0 { | ||||
| @@ -1821,3 +1837,158 @@ func Unmount(name string, mtm int) (err error) { | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func fdToPath(dirfd int) (path string, err error) { | ||||
| 	var buffer [1024]byte | ||||
| 	// w_ctrl() | ||||
| 	ret := runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS_W_IOCTL<<4, | ||||
| 		[]uintptr{uintptr(dirfd), 17, 1024, uintptr(unsafe.Pointer(&buffer[0]))}) | ||||
| 	if ret == 0 { | ||||
| 		zb := bytes.IndexByte(buffer[:], 0) | ||||
| 		if zb == -1 { | ||||
| 			zb = len(buffer) | ||||
| 		} | ||||
| 		// __e2a_l() | ||||
| 		runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS___E2A_L<<4, | ||||
| 			[]uintptr{uintptr(unsafe.Pointer(&buffer[0])), uintptr(zb)}) | ||||
| 		return string(buffer[:zb]), nil | ||||
| 	} | ||||
| 	// __errno() | ||||
| 	errno := int(*(*int32)(unsafe.Pointer(runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS___ERRNO<<4, | ||||
| 		[]uintptr{})))) | ||||
| 	// __errno2() | ||||
| 	errno2 := int(runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS___ERRNO2<<4, | ||||
| 		[]uintptr{})) | ||||
| 	// strerror_r() | ||||
| 	ret = runtime.CallLeFuncByPtr(runtime.XplinkLibvec+SYS_STRERROR_R<<4, | ||||
| 		[]uintptr{uintptr(errno), uintptr(unsafe.Pointer(&buffer[0])), 1024}) | ||||
| 	if ret == 0 { | ||||
| 		zb := bytes.IndexByte(buffer[:], 0) | ||||
| 		if zb == -1 { | ||||
| 			zb = len(buffer) | ||||
| 		} | ||||
| 		return "", fmt.Errorf("%s (errno2=0x%x)", buffer[:zb], errno2) | ||||
| 	} else { | ||||
| 		return "", fmt.Errorf("fdToPath errno %d (errno2=0x%x)", errno, errno2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func direntLeToDirentUnix(dirent *direntLE, dir uintptr, path string) (Dirent, error) { | ||||
| 	var d Dirent | ||||
|  | ||||
| 	d.Ino = uint64(dirent.Ino) | ||||
| 	offset, err := Telldir(dir) | ||||
| 	if err != nil { | ||||
| 		return d, err | ||||
| 	} | ||||
|  | ||||
| 	d.Off = int64(offset) | ||||
| 	s := string(bytes.Split(dirent.Name[:], []byte{0})[0]) | ||||
| 	copy(d.Name[:], s) | ||||
|  | ||||
| 	d.Reclen = uint16(24 + len(d.NameString())) | ||||
| 	var st Stat_t | ||||
| 	path = path + "/" + s | ||||
| 	err = Lstat(path, &st) | ||||
| 	if err != nil { | ||||
| 		return d, err | ||||
| 	} | ||||
|  | ||||
| 	d.Type = uint8(st.Mode >> 24) | ||||
| 	return d, err | ||||
| } | ||||
|  | ||||
| func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) { | ||||
| 	// Simulation of Getdirentries port from the Darwin implementation. | ||||
| 	// COMMENTS FROM DARWIN: | ||||
| 	// It's not the full required semantics, but should handle the case | ||||
| 	// of calling Getdirentries or ReadDirent repeatedly. | ||||
| 	// It won't handle assigning the results of lseek to *basep, or handle | ||||
| 	// the directory being edited underfoot. | ||||
|  | ||||
| 	skip, err := Seek(fd, 0, 1 /* SEEK_CUR */) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	// Get path from fd to avoid unavailable call (fdopendir) | ||||
| 	path, err := fdToPath(fd) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	d, err := Opendir(path) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	defer Closedir(d) | ||||
|  | ||||
| 	var cnt int64 | ||||
| 	for { | ||||
| 		var entryLE direntLE | ||||
| 		var entrypLE *direntLE | ||||
| 		e := readdir_r(d, &entryLE, &entrypLE) | ||||
| 		if e != nil { | ||||
| 			return n, e | ||||
| 		} | ||||
| 		if entrypLE == nil { | ||||
| 			break | ||||
| 		} | ||||
| 		if skip > 0 { | ||||
| 			skip-- | ||||
| 			cnt++ | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// Dirent on zos has a different structure | ||||
| 		entry, e := direntLeToDirentUnix(&entryLE, d, path) | ||||
| 		if e != nil { | ||||
| 			return n, e | ||||
| 		} | ||||
|  | ||||
| 		reclen := int(entry.Reclen) | ||||
| 		if reclen > len(buf) { | ||||
| 			// Not enough room. Return for now. | ||||
| 			// The counter will let us know where we should start up again. | ||||
| 			// Note: this strategy for suspending in the middle and | ||||
| 			// restarting is O(n^2) in the length of the directory. Oh well. | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		// Copy entry into return buffer. | ||||
| 		s := unsafe.Slice((*byte)(unsafe.Pointer(&entry)), reclen) | ||||
| 		copy(buf, s) | ||||
|  | ||||
| 		buf = buf[reclen:] | ||||
| 		n += reclen | ||||
| 		cnt++ | ||||
| 	} | ||||
| 	// Set the seek offset of the input fd to record | ||||
| 	// how many files we've already returned. | ||||
| 	_, err = Seek(fd, cnt, 0 /* SEEK_SET */) | ||||
| 	if err != nil { | ||||
| 		return n, err | ||||
| 	} | ||||
|  | ||||
| 	return n, nil | ||||
| } | ||||
|  | ||||
| func ReadDirent(fd int, buf []byte) (n int, err error) { | ||||
| 	var base = (*uintptr)(unsafe.Pointer(new(uint64))) | ||||
| 	return Getdirentries(fd, buf, base) | ||||
| } | ||||
|  | ||||
| func direntIno(buf []byte) (uint64, bool) { | ||||
| 	return readInt(buf, unsafe.Offsetof(Dirent{}.Ino), unsafe.Sizeof(Dirent{}.Ino)) | ||||
| } | ||||
|  | ||||
| func direntReclen(buf []byte) (uint64, bool) { | ||||
| 	return readInt(buf, unsafe.Offsetof(Dirent{}.Reclen), unsafe.Sizeof(Dirent{}.Reclen)) | ||||
| } | ||||
|  | ||||
| func direntNamlen(buf []byte) (uint64, bool) { | ||||
| 	reclen, ok := direntReclen(buf) | ||||
| 	if !ok { | ||||
| 		return 0, false | ||||
| 	} | ||||
| 	return reclen - uint64(unsafe.Offsetof(Dirent{}.Name)), true | ||||
| } | ||||
|   | ||||
							
								
								
									
										13
									
								
								vendor/golang.org/x/sys/unix/sysvshm_unix.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										13
									
								
								vendor/golang.org/x/sys/unix/sysvshm_unix.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -7,11 +7,7 @@ | ||||
|  | ||||
| package unix | ||||
|  | ||||
| import ( | ||||
| 	"unsafe" | ||||
|  | ||||
| 	"golang.org/x/sys/internal/unsafeheader" | ||||
| ) | ||||
| import "unsafe" | ||||
|  | ||||
| // SysvShmAttach attaches the Sysv shared memory segment associated with the | ||||
| // shared memory identifier id. | ||||
| @@ -34,12 +30,7 @@ func SysvShmAttach(id int, addr uintptr, flag int) ([]byte, error) { | ||||
| 	} | ||||
|  | ||||
| 	// Use unsafe to convert addr into a []byte. | ||||
| 	// TODO: convert to unsafe.Slice once we can assume Go 1.17 | ||||
| 	var b []byte | ||||
| 	hdr := (*unsafeheader.Slice)(unsafe.Pointer(&b)) | ||||
| 	hdr.Data = unsafe.Pointer(addr) | ||||
| 	hdr.Cap = int(info.Segsz) | ||||
| 	hdr.Len = int(info.Segsz) | ||||
| 	b := unsafe.Slice((*byte)(unsafe.Pointer(addr)), int(info.Segsz)) | ||||
| 	return b, nil | ||||
| } | ||||
|  | ||||
|   | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user