mirror of
				https://github.com/mochi-mqtt/server.git
				synced 2025-10-26 17:40:38 +08:00 
			
		
		
		
	Compare commits
	
		
			194 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | d8f28cb843 | ||
|   | 88861c219d | ||
|   | 7ba6cf28d9 | ||
|   | c174cfdc6b | ||
|   | 4f198a99dd | ||
|   | 2a9c9fcc40 | ||
|   | 835a85c8bf | ||
|   | fe5d9ffa61 | ||
|   | aac186dcc1 | ||
|   | 42931f332f | ||
|   | 8a04648c09 | ||
|   | 854c033fb6 | ||
|   | 74ed8cd046 | ||
|   | be164fa715 | ||
|   | 4287955161 | ||
|   | bbf08ff496 | ||
|   | c38201ff8b | ||
|   | f8b4ff5c0d | ||
|   | 661e23e051 | ||
|   | 9c99db426c | ||
|   | 40b7273a53 | ||
|   | 898c90d4ca | ||
|   | bc3d8b0eaa | ||
|   | 35bd928714 | ||
|   | 20c2655d0e | ||
|   | fec25f29c6 | ||
|   | 1d7a322229 | ||
|   | d8b38a4ae2 | ||
|   | a83c0c4fd0 | ||
|   | 66a1d19e89 | ||
|   | 0dbebbc066 | ||
|   | f22b8276e8 | ||
|   | d60c438960 | ||
|   | b2fc287a98 | ||
|   | 3e3ba20b08 | ||
|   | 9ee462c777 | ||
|   | 3c89114bba | ||
|   | ecbd07fa3a | ||
|   | ad8bf2a931 | ||
|   | b8fb068bb9 | ||
|   | c1348a37b8 | ||
|   | 84fc2f848b | ||
|   | 8703d6d020 | ||
|   | 666440fe56 | ||
|   | 1ae050939a | ||
|   | f4683d27d0 | ||
|   | dff2b1db30 | ||
|   | 9de6b4e427 | ||
|   | 78c1914270 | ||
|   | f71bf5c3d6 | ||
|   | 53c4a6b09f | ||
|   | a02c6bd8df | ||
|   | d8f6d63cc8 | ||
|   | bef13eec20 | ||
|   | 27f3c484ad | ||
|   | 9b5cdb0bcc | ||
|   | 2b60a11d4a | ||
|   | b53774f818 | ||
|   | 7dee729afb | ||
|   | aed535b7bf | ||
|   | 4ff888ab3b | ||
|   | 31252c081b | ||
|   | 3a15cc3add | ||
|   | 765f6e7c2e | ||
|   | e3bdfc1f8e | ||
|   | 60cd972b7f | ||
|   | ee081d0abe | ||
|   | d97b4bb81d | ||
|   | 94aeacf0cb | ||
|   | 5cb8a081a1 | ||
|   | fc00112e47 | ||
|   | bbc22fae5b | ||
|   | 82cb75913d | ||
|   | 45c4a64b87 | ||
|   | bae2579497 | ||
|   | 1bc01271cb | ||
|   | 352a71f50c | ||
|   | 6f9f62e38f | ||
|   | 0f67d9e8ff | ||
|   | f2dd5b63ae | ||
|   | 54e2d044a2 | ||
|   | 6298a87298 | ||
|   | b6fd25bba4 | ||
|   | eef3592576 | ||
|   | 5d343c12e1 | ||
|   | 70f52c8a3b | ||
|   | 429b72265a | ||
|   | f60d2dcfca | ||
|   | 6674cd64eb | ||
|   | f218cde69c | ||
|   | 9ea687eb94 | ||
|   | 949e4e2e91 | ||
|   | 515e0269de | ||
|   | ae6073c79c | ||
|   | b072a08f0b | ||
|   | 01d8a450d2 | ||
|   | da2fd41f79 | ||
|   | 56e8039093 | ||
|   | 7b9bc844c1 | ||
|   | 8acb182820 | ||
|   | 5726880095 | ||
|   | ee459e1b3d | ||
|   | 6aec3a8bbf | ||
|   | 74699f0a87 | ||
|   | 70def39ff9 | ||
|   | 8e7098a32d | ||
|   | e4f02919fd | ||
|   | 99c96c844e | ||
|   | 18629aea6d | ||
|   | a0060429d1 | ||
|   | d946a9ae16 | ||
|   | 51a2eb5f48 | ||
|   | 0d4b0a89d8 | ||
|   | 0e7ccfe3fb | ||
|   | 5d7230630d | ||
|   | 6a3cbd6093 | ||
|   | 7b4e79707b | ||
|   | c6643592f6 | ||
|   | 5de12d0460 | ||
|   | 0a7205e110 | ||
|   | 8133dd8299 | ||
|   | fdbfff57dc | ||
|   | f5fc5e8c44 | ||
|   | 9f44712b80 | ||
|   | 1f86168d9d | ||
|   | ab25083ed2 | ||
|   | 9b0aa4d559 | ||
|   | 03814944a9 | ||
|   | 3286d5a484 | ||
|   | 7e970d3c7a | ||
|   | d6a92cc5bd | ||
|   | 325d44d478 | ||
|   | 0a5f6d3a9d | ||
|   | 17253ad8bd | ||
|   | 9f1c387091 | ||
|   | 9c6f602630 | ||
|   | b0dcaabdde | ||
|   | 460f0ef681 | ||
|   | 6e16765f60 | ||
|   | 2b361df19e | ||
|   | c8c0a5a094 | ||
|   | 4a833dd081 | ||
|   | 81198d9845 | ||
|   | 6c12d8a71a | ||
|   | 19b598b672 | ||
|   | b6529f05d3 | ||
|   | 7f76445cc8 | ||
|   | b1c01792cd | ||
|   | eda03d4338 | ||
|   | 18070f1f57 | ||
|   | 7f10c28a37 | ||
|   | 122531bb27 | ||
|   | e6dbcae428 | ||
|   | 98875de568 | ||
|   | c9fd9451af | ||
|   | 6550b8d680 | ||
|   | a60c96c889 | ||
|   | 86e0a5827e | ||
|   | 06c399b606 | ||
|   | ed117f67a1 | ||
|   | 880a3299e1 | ||
|   | 1c408d05be | ||
|   | fce495f83e | ||
|   | 471ca00a64 | ||
|   | a2c0749640 | ||
|   | 37293aeecf | ||
|   | 7a2d4db6a4 | ||
|   | 03d2a8bc82 | ||
|   | 4b51e5c7d1 | ||
|   | d15ad682bf | ||
|   | 130ffcbb53 | ||
|   | 33cf2f991b | ||
|   | a360ea6a6c | ||
|   | ae3aa0d3fa | ||
|   | 811ae0e1be | ||
|   | 51d6825430 | ||
|   | 514288c53e | ||
|   | 957fc0a049 | ||
|   | 03f94f948a | ||
|   | 1bc752a2b8 | ||
|   | b9db59ba12 | ||
|   | c0ef58c363 | ||
|   | 994adea3b4 | ||
|   | fc61cc9be5 | ||
|   | 22d7338878 | ||
|   | 3f28515706 | ||
|   | 7d73ce9caf | ||
|   | 0758bc961c | ||
|   | 8472b9ae8a | ||
|   | 530a018e80 | ||
|   | 0b594afb4e | ||
|   | 9d0ea957bb | ||
|   | 8067785ac4 | ||
|   | 6ffc8a8388 | 
							
								
								
									
										43
									
								
								.github/workflows/build.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								.github/workflows/build.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| name: build | ||||
|  | ||||
| on: [push, pull_request] | ||||
|  | ||||
| jobs: | ||||
|  | ||||
|   build: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|     - uses: actions/checkout@v2 | ||||
|  | ||||
|     - name: Set up Go | ||||
|       uses: actions/setup-go@v2 | ||||
|       with: | ||||
|         go-version: 1.19 | ||||
|  | ||||
|     - name: Vet | ||||
|       run: go vet ./... | ||||
|      | ||||
|     - name: Test | ||||
|       run: go test -race ./... && echo true | ||||
|        | ||||
|  | ||||
|   coverage: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - name: Install Go | ||||
|         if: success() | ||||
|         uses: actions/setup-go@v2 | ||||
|         with: | ||||
|           go-version: 1.19.x | ||||
|       - name: Checkout code | ||||
|         uses: actions/checkout@v2 | ||||
|       - name: Calc coverage | ||||
|         run: | | ||||
|           go test -v -covermode=count -coverprofile=coverage.out ./... | ||||
|       - name: Convert coverage.out to coverage.lcov | ||||
|         uses: jandelgado/gcov2lcov-action@v1.0.6 | ||||
|       - name: Coveralls | ||||
|         uses: coverallsapp/github-action@v1.1.2 | ||||
|         with: | ||||
|           github-token: ${{ secrets.github_token }} | ||||
|           path-to-lcov: coverage.lcov | ||||
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,2 +1,3 @@ | ||||
| cmd/mqtt | ||||
| .DS_Store | ||||
| .DS_Store | ||||
| *.db | ||||
|   | ||||
							
								
								
									
										103
									
								
								.golangci.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								.golangci.yml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,103 @@ | ||||
| linters: | ||||
|   disable-all: false | ||||
|   fix: false # Fix found issues (if it's supported by the linter). | ||||
|   enable: | ||||
|     # - asasalint | ||||
|     # - asciicheck | ||||
|     # - bidichk | ||||
|     # - bodyclose | ||||
|     # - containedctx | ||||
|     # - contextcheck | ||||
|     #- cyclop | ||||
|     # - deadcode | ||||
|     - decorder | ||||
|     # - depguard | ||||
|     # - dogsled | ||||
|     # - dupl | ||||
|     - durationcheck | ||||
|     # - errchkjson | ||||
|     # - errname | ||||
|     - errorlint | ||||
|     # - execinquery | ||||
|     # - exhaustive | ||||
|     # - exhaustruct | ||||
|     # - exportloopref | ||||
|     #- forcetypeassert | ||||
|     #- forbidigo | ||||
|     #- funlen | ||||
|     #- gci | ||||
|     # - gochecknoglobals | ||||
|     # - gochecknoinits | ||||
|     # - gocognit | ||||
|     # - goconst | ||||
|     # - gocritic | ||||
|     - gocyclo | ||||
|     - godot | ||||
|     # - godox | ||||
|     # - goerr113 | ||||
|     # - gofmt | ||||
|     # - gofumpt | ||||
|     # - goheader | ||||
|     - goimports | ||||
|     # - golint | ||||
|     # - gomnd | ||||
|     # - gomoddirectives | ||||
|     # - gomodguard | ||||
|     # - goprintffuncname | ||||
|     - gosec | ||||
|     - gosimple | ||||
|     - govet | ||||
|     # - grouper | ||||
|     # - ifshort | ||||
|     - importas | ||||
|     - ineffassign | ||||
|     # - interfacebloat | ||||
|     # - interfacer | ||||
|     # - ireturn | ||||
|     # - lll | ||||
|     # - maintidx | ||||
|     # - makezero | ||||
|     - maligned | ||||
|     - misspell | ||||
|     # - nakedret | ||||
|     # - nestif | ||||
|     # - nilerr | ||||
|     # - nilnil | ||||
|     # - nlreturn | ||||
|     # - noctx | ||||
|     # - nolintlint | ||||
|     # - nonamedreturns | ||||
|     # - nosnakecase | ||||
|     # - nosprintfhostport | ||||
|     # - paralleltest | ||||
|     # - prealloc | ||||
|     # - predeclared | ||||
|     # - promlinter | ||||
|     - reassign | ||||
|     # - revive | ||||
|     # - rowserrcheck | ||||
|     # - scopelint | ||||
|     # - sqlclosecheck | ||||
|     # - staticcheck | ||||
|     # - structcheck | ||||
|     # - stylecheck | ||||
|     # - tagliatelle | ||||
|     # - tenv | ||||
|     # - testpackage | ||||
|     # - thelper | ||||
|     - tparallel | ||||
|     # - typecheck | ||||
|     - unconvert | ||||
|     - unparam | ||||
|     - unused | ||||
|     - usestdlibvars | ||||
|     # - varcheck | ||||
|     # - varnamelen | ||||
|     - wastedassign | ||||
|     - whitespace | ||||
|     # - wrapcheck | ||||
|     # - wsl | ||||
|   disable: | ||||
|     - errcheck | ||||
|  | ||||
|  | ||||
							
								
								
									
										18
									
								
								.travis.yml
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								.travis.yml
									
									
									
									
									
								
							| @@ -1,18 +0,0 @@ | ||||
| dist: xenial | ||||
|  | ||||
| language: go | ||||
|  | ||||
| env: | ||||
|   - GO111MODULE=on | ||||
|  | ||||
| go: | ||||
|   - 1.13.x | ||||
|  | ||||
| git: | ||||
|   depth: 1 | ||||
|    | ||||
| script: | ||||
|   - go test -v -race ./... -coverprofile=coverage.txt -covermode=atomic | ||||
|  | ||||
| after_success: | ||||
|   - bash <(curl -s https://codecov.io/bash) | ||||
							
								
								
									
										31
									
								
								Dockerfile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								Dockerfile
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| FROM golang:1.19.0-alpine3.15 AS builder | ||||
|  | ||||
| RUN apk update | ||||
| RUN apk add git | ||||
|  | ||||
| WORKDIR /app | ||||
|  | ||||
| COPY go.mod ./ | ||||
| COPY go.sum ./ | ||||
| RUN go mod download | ||||
|  | ||||
| COPY . ./ | ||||
|  | ||||
| RUN go build -o /app/mochi ./cmd | ||||
|  | ||||
|  | ||||
| FROM alpine | ||||
|  | ||||
| WORKDIR / | ||||
| COPY --from=builder /app/mochi . | ||||
|  | ||||
| # tcp | ||||
| EXPOSE 1883 | ||||
|  | ||||
| # websockets | ||||
| EXPOSE 1882 | ||||
|  | ||||
| # dashboard | ||||
| EXPOSE 8080 | ||||
|  | ||||
| ENTRYPOINT [ "/mochi" ] | ||||
| @@ -1,7 +1,7 @@ | ||||
|  | ||||
| The MIT License (MIT) | ||||
|  | ||||
| Copyright (c) 2019 Jonathan Blake (mochi) | ||||
| Copyright (c) 2019, 2022 Jonathan Blake (mochi-co) | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
|   | ||||
							
								
								
									
										481
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										481
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,228 +1,391 @@ | ||||
|  | ||||
| <p align="center"> | ||||
|    | ||||
| [](https://travis-ci.com/mochi-co/mqtt) | ||||
|  | ||||
|   | ||||
| [](https://coveralls.io/github/mochi-co/mqtt?branch=master) | ||||
| [](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2) | ||||
| [](https://pkg.go.dev/github.com/mochi-co/mqtt/v2) | ||||
| [](https://github.com/mochi-co/mqtt/issues) | ||||
| [](https://codecov.io/gh/mochi-co/mqtt) | ||||
| [](https://pkg.go.dev/github.com/mochi-co/mqtt) | ||||
|  | ||||
| </p> | ||||
|  | ||||
| # Mochi MQTT  | ||||
| ### A High-performance MQTT server in Go (v3.0 | v3.1.1)  | ||||
| # Mochi MQTT Broker | ||||
| ## The fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker server  | ||||
| Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) MQTT v5 broker server written in Go, designed for the development of telemetry and internet-of-things projects. The server can be used either as a standalone binary or embedded as a library in your own applications, and has been designed to be as lightweight and fast as possible, with great care taken to ensure the quality and maintainability of the project.  | ||||
|  | ||||
| Mochi MQTT is an embeddable high-performance MQTT broker server written in Go, and compliant with the MQTT v3.0 and v3.1.1 specification for the development of IoT and smarthome projects. The server can be used either as a standalone binary or embedded as a library in your own projects. Mochi MQTT message throughput is comparable with everyone's favourites such as Mosquitto, Mosca, and VerneMQ. | ||||
| ### What is MQTT? | ||||
| MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol. | ||||
|  | ||||
| #### What is MQTT? | ||||
| MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks. [Learn more](https://mqtt.org/faq) | ||||
| ## What's new in Version 2.0.0? | ||||
| Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.  | ||||
|  | ||||
| #### Mochi MQTT Features | ||||
| - Paho MQTT 3.0 / 3.1.1 compatible.  | ||||
| - Full MQTT Feature-set (QoS, Retained, $SYS) | ||||
| - Trie-based Subscription model. | ||||
| - Ring Buffer packet codec. | ||||
| - TCP, Websocket, (including SSL/TLS) and Dashboard listeners. | ||||
| - Interfaces for Client Authentication and Topic access control. | ||||
| - Bolt-backed persistence and storage interfaces. | ||||
| - Directly Publishing from embedding service (`s.Publish(topic, message, retain)`). | ||||
| - Basic Event Hooks (currently `onMessage`) | ||||
| Don't forget to use the new v2 import paths: | ||||
| ```go | ||||
| import "github.com/mochi-co/mqtt/v2" | ||||
| ``` | ||||
|  | ||||
| #### Roadmap | ||||
| - MQTT v5 compatibility? | ||||
| - Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0: | ||||
|     - User and MQTTv5 Packet Properties | ||||
|     - Topic Aliases | ||||
|     - Shared Subscriptions | ||||
|     - Subscription Options and Subscription Identifiers | ||||
|     - Message Expiry | ||||
|     - Client Session Expiry | ||||
|     - Send and Receive QoS Flow Control Quotas | ||||
|     - Server-side Disconnect and Auth Packets | ||||
|     - Will Delay Intervals | ||||
|     - Plus all the original MQTT features of Mochi MQTT v1, such as Full QoS(0,1,2), $SYS topics, retained messages, etc.  | ||||
| - Developer-centric: | ||||
|     - Most core broker code is now exported and accessible, for total developer control. | ||||
|     - Full featured and flexible Hook-based interfacing system to provide easy 'plugin' development. | ||||
|     - Direct Packet Injection using special inline client, or masquerade as existing clients. | ||||
| - Performant and Stable: | ||||
|     - Our classic trie-based Topic-Subscription model. | ||||
|     - A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.  | ||||
|     - Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3. | ||||
|     - Over a thousand carefully considered unit test scenarios. | ||||
| - TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners. | ||||
| - Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own). | ||||
| - Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own). | ||||
|  | ||||
| #### Using the Broker | ||||
| Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the `main.go` entrypoint in the `cmd` folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners. A docker image is coming soon. | ||||
| > There is no upgrade path from v1.0.0. Please review the documentation and this readme to get a sense of the changes required (e.g. the v1 events system, auth, and persistence have all been replaced with the new hooks system). | ||||
|  | ||||
| ### Compatibility Notes | ||||
| Because of the overlap between the v5 specification and previous versions of mqtt, the server can accept both v5 and v3 clients, but note that in cases where both v5 an v3 clients are connected, properties and features provided for v5 clients will be downgraded for v3 clients (such as user properties). | ||||
|  | ||||
| Support for MQTT v3.0.0 and v3.1.1 is considered hybrid-compatibility. Where not specifically restricted in the v3 specification, more modern and safety-first v5 behaviours are used instead - such as expiry for inflight and retained messages, and clients - and quality-of-service flow control limits. | ||||
|  | ||||
| ## Roadmap | ||||
| - Please [open an issue](https://github.com/mochi-co/mqtt/issues) to request new features or event hooks! | ||||
| - Cluster support. | ||||
| - Enhanced Metrics support. | ||||
| - File-based server configuration (supporting docker). | ||||
|  | ||||
| ## Quick Start | ||||
| ### Running the Broker with Go | ||||
| Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the [cmd/main.go](cmd/main.go) entrypoint in the [cmd](cmd) folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners. | ||||
|  | ||||
| ``` | ||||
| cd cmd | ||||
| go build -o mqtt && ./mqtt | ||||
| ``` | ||||
|  | ||||
| #### Quick Start | ||||
| ### Using Docker | ||||
| A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Websocket, TCP, and Stats server: | ||||
|  | ||||
| ```sh | ||||
| docker build -t mochi:latest . | ||||
| docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest | ||||
| ``` | ||||
|  | ||||
| ## Developing with Mochi MQTT | ||||
| ### Importing as a package | ||||
| Importing Mochi MQTT as a package requires just a few lines of code to get started. | ||||
| ``` go | ||||
| import ( | ||||
|   mqtt "github.com/mochi-co/mqtt/server" | ||||
|   "github.com/mochi-co/mqtt/v2" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
|     // Create the new MQTT Server. | ||||
|     server := mqtt.New() | ||||
| 	 | ||||
|     // Create a TCP listener on a standard port. | ||||
|     tcp := listeners.NewTCP("t1", ":1883") | ||||
| 	 | ||||
|     // Add the listener to the server with default options (nil). | ||||
|     err := server.AddListener(tcp, nil) | ||||
|     if err != nil { | ||||
|         log.Fatal(err) | ||||
|     } | ||||
| 	 | ||||
|     // Start the broker. Serve() is blocking - see examples folder  | ||||
|     // for usage ideas. | ||||
|     err = server.Serve() | ||||
|     if err != nil { | ||||
|         log.Fatal(err) | ||||
|     } | ||||
|   // Create the new MQTT Server. | ||||
|   server := mqtt.New(nil) | ||||
|  | ||||
|   // Allow all connections. | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
|   // Create a TCP listener on a standard port. | ||||
| 	tcp := listeners.NewTCP("t1", *tcpAddr, nil) | ||||
| 	err := server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|    | ||||
|   err = server.Serve() | ||||
|   if err != nil { | ||||
|     log.Fatal(err) | ||||
|   } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| Examples of running the broker with various configurations can be found in the `examples` folder.  | ||||
| Examples of running the broker with various configurations can be found in the [examples](examples) folder.  | ||||
|  | ||||
| #### Network Listeners | ||||
| The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are: | ||||
| - `listeners.NewTCP(id, address string)` - A TCP Listener, taking a unique ID and a network address to bind. | ||||
| - `listeners.NewWebsocket(id, address string)` A Websocket Listener | ||||
| - `listeners.NewHTTPStats()` An HTTP $SYS info dashboard | ||||
|  | ||||
| ##### Configuring Network Listeners | ||||
| When a listener is added to the server using `server.AddListener`, a `*listeners.Config` may be passed as the second argument. | ||||
| - `listeners.NewTCP(...)` - A TCP listener. | ||||
| - `listeners.NewWebsocket(...)` A Websocket listener. | ||||
| - `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard. | ||||
| - Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know! | ||||
|  | ||||
| ##### Authentication and ACL | ||||
| Authentication and ACL may be configured on a per-listener basis by providing an Auth Controller to the listener configuration. Custom Auth Controllers should satisfy the `auth.Controller` interface found in `listeners/auth`. Two default controllers are provided, `auth.Allow`, which allows all traffic, and `auth.Disallow`, which denies all traffic.  | ||||
| A `*listeners.Config` may be passed to configure TLS.  | ||||
|  | ||||
| Examples of usage can be found in the [examples](examples) folder or [cmd/main.go](cmd/main.go). | ||||
|  | ||||
| ### Server Options and Capabilities | ||||
| A number of configurable options are available which can be used to alter the behaviour or restrict access to certain features in the server. | ||||
|  | ||||
| ```go | ||||
| err := server.AddListener(tcp, &listeners.Config{ | ||||
| 	Auth: new(auth.Allow), | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| > If no auth controller is provided in the listener configuration, the server will default to _Disallowing_ all traffic to prevent unintentional security issues. | ||||
|  | ||||
| ##### SSL | ||||
| SSL may be configured on both the TCP and Websocket listeners by providing a public-private PEM key pair to the listener configuration as `[]byte` slices. | ||||
| ```go | ||||
| err := server.AddListener(tcp, &listeners.Config{ | ||||
|     Auth: new(auth.Allow), | ||||
|     TLS: &listeners.TLS{ | ||||
|         Certificate: publicCertificate,  | ||||
|         PrivateKey:  privateKey, | ||||
| server := mqtt.New(&mqtt.Options{ | ||||
|   Capabilities: mqtt.Capabilities{ | ||||
|     MaximumSessionExpiryInterval: 3600, | ||||
|     Compatibilities: mqtt.Compatibilities{ | ||||
|       ObscureNotAuthorized: true, | ||||
|     }, | ||||
|   }, | ||||
|   SysTopicResendInterval: 10, | ||||
| }) | ||||
| ``` | ||||
| > Note the mandatory inclusion of the Auth Controller! | ||||
|  | ||||
| #### Event Hooks | ||||
| Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service. | ||||
| Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. | ||||
|  | ||||
| ##### OnMessage | ||||
| `server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns. | ||||
|  | ||||
| > This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method. | ||||
| ## Event Hooks  | ||||
| A universal event hooks system allows developers to hook into various parts of the server and client life cycle to add and modify functionality of the broker. These universal hooks are used to provide everything from authentication, persistent storage, to debugging tools. | ||||
|  | ||||
| Hooks are stackable - you can add multiple hooks to a server, and they will be run in the order they were added. Some hooks modify values, and these modified values will be passed to the subsequent hooks before being returned to the runtime code. | ||||
|  | ||||
| | Type | Import | Info | | ||||
| | -- | -- |  -- | | ||||
| | Access Control | [mochi-co/mqtt/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to  all topics. |  | ||||
| | Access Control | [mochi-co/mqtt/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger.  |  | ||||
| | Persistence | [mochi-co/mqtt/hooks/storage/bolt](hooks/storage/bolt/bolt.go)  | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |  | ||||
| | Persistence | [mochi-co/mqtt/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |  | ||||
| | Persistence | [mochi-co/mqtt/hooks/storage/redis](hooks/storage/redis/redis.go)  | Persistent storage using [Redis](https://redis.io). |  | ||||
| | Debugging | [mochi-co/mqtt/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |  | ||||
|  | ||||
| Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-co/mqtt/issues) and let everyone know! | ||||
|  | ||||
| ### Access Control  | ||||
| #### Allow Hook | ||||
| By default, Mochi MQTT uses a DENY-ALL access control rule. To allow connections, this must overwritten using an Access Control hook. The simplest of these hooks is the `auth.AllowAll` hook, which provides ALLOW-ALL rules to all connections, subscriptions, and publishing. It's also the simplest hook to use: | ||||
|  | ||||
| ```go | ||||
| import "github.com/mochi-co/mqtt/server/events" | ||||
|  | ||||
| server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) { | ||||
|     if string(pk.Payload) == "hello" { | ||||
|         pkx = pk | ||||
|         pkx.Payload = []byte("hello world") | ||||
|         return pkx, nil | ||||
|     }  | ||||
|      | ||||
|     return pk, nil | ||||
| } | ||||
| server := mqtt.New(nil) | ||||
| _ = server.AddHook(new(auth.AllowHook), nil) | ||||
| ``` | ||||
|  | ||||
| A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in! | ||||
| > Don't do this if you are exposing your server to the internet or untrusted networks - it should really be used for development, testing, and debugging only. | ||||
|  | ||||
| #### Direct Publishing | ||||
| When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported. | ||||
| #### Auth Ledger | ||||
| The Auth Ledger hook provides a sophisticated mechanism for defining access rules in a struct format. Auth ledger rules come in two forms: Auth rules (connection), and ACL rules (publish subscribe).  | ||||
|  | ||||
| ```go  | ||||
| // func (s *Server) Publish(topic string, payload []byte, retain bool) error | ||||
| err := s.Publish("a/b/c", []byte("hello"), false) | ||||
| if err != nil { | ||||
|     log.Fatal(err) | ||||
| } | ||||
| ``` | ||||
| Auth rules have 4 optional criteria and an assertion flag: | ||||
| | Criteria | Usage |  | ||||
| | -- | -- | | ||||
| | Client | client id of the connecting client | | ||||
| | Username | username of the connecting client | | ||||
| | Password | password of the connecting client | | ||||
| | Remote | the remote address or ip of the client | | ||||
| | Allow | true (allow this user) or false (deny this user) |  | ||||
|  | ||||
| A working example can be found in the `examples/events` folder. | ||||
| ACL rules have 3 optional criteria and an filter match: | ||||
| | Criteria | Usage |  | ||||
| | -- | -- | | ||||
| | Client | client id of the connecting client | | ||||
| | Username | username of the connecting client | | ||||
| | Remote | the remote address or ip of the client | | ||||
| | Filters | an array of filters to match | | ||||
|  | ||||
| Rules are processed in index order (0,1,2,3), returning on the first matching rule. See [hooks/auth/ledger.go](hooks/auth/ledger.go) to review the structs. | ||||
|  | ||||
| #### Data Persistence | ||||
| Mochi MQTT provides a `persistence.Store` interface for developing and attaching persistent stores to the broker. The default persistence mechanism packaged with the broker is backed by [Bolt](https://github.com/etcd-io/bbolt) and can be enabled by assigning a `*bolt.Store` to the server. | ||||
| ```go | ||||
| // import "github.com/mochi-co/mqtt/server/persistence/bolt" | ||||
| err = server.AddStore(bolt.New("mochi.db", nil)) | ||||
| server := mqtt.New(nil) | ||||
| err := server.AddHook(new(auth.Hook), &auth.Options{ | ||||
|     Ledger: &auth.Ledger{ | ||||
|     Auth: auth.AuthRules{ // Auth disallows all by default | ||||
|       {Username: "peach", Password: "password1", Allow: true}, | ||||
|       {Username: "melon", Password: "password2", Allow: true}, | ||||
|       {Remote: "127.0.0.1:*", Allow: true}, | ||||
|       {Remote: "localhost:*", Allow: true}, | ||||
|     }, | ||||
|     ACL: auth.ACLRules{ // ACL allows all by default | ||||
|       {Remote: "127.0.0.1:*"}, // local superuser allow all | ||||
|       { | ||||
|         // user melon can read and write to their own topic | ||||
|         Username: "melon", Filters: auth.Filters{ | ||||
|           "melon/#":   auth.ReadWrite, | ||||
|           "updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others | ||||
|         }, | ||||
|       }, | ||||
|       { | ||||
|         // Otherwise, no clients have publishing permissions | ||||
|         Filters: auth.Filters{ | ||||
|           "#":         auth.ReadOnly, | ||||
|           "updates/#": auth.Deny, | ||||
|         }, | ||||
|       }, | ||||
|     }, | ||||
|   } | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| The ledger can also be stored as JSON or YAML and loaded using the Data field: | ||||
| ```go | ||||
| err = server.AddHook(new(auth.Hook), &auth.Options{ | ||||
|     Data: data, // build ledger from byte slice: yaml or json | ||||
| }) | ||||
| ``` | ||||
| See [examples/auth/encoded/main.go](examples/auth/encoded/main.go) for more information. | ||||
|  | ||||
| ### Persistent Storage  | ||||
| #### Redis | ||||
| A basic Redis storage hook is available which provides persistence for the broker. It can be added to the server in the same fashion as any other hook, with several options. It uses github.com/go-redis/redis/v8 under the hook, and is completely configurable through the Options value.  | ||||
| ```go | ||||
| err := server.AddHook(new(redis.Hook), &redis.Options{ | ||||
|   Options: &rv8.Options{ | ||||
|     Addr:     "localhost:6379", // default redis address | ||||
|     Password: "",               // your password | ||||
|     DB:       0,                // your redis db | ||||
|   }, | ||||
| }) | ||||
| if err != nil { | ||||
|     log.Fatal(err) | ||||
|   log.Fatal(err) | ||||
| } | ||||
| ``` | ||||
| > Persistence is on-demand (not flushed) and will potentially reduce throughput when compared to the standard in-memory store. Only use it if you need to maintain state through restarts. | ||||
| For more information on how the redis hook works, or how to use it, see the [examples/persistence/redis/main.go](examples/persistence/redis/main.go) or [hooks/storage/redis](hooks/storage/redis) code. | ||||
|  | ||||
| #### Badger DB | ||||
| There's also a BadgerDB storage hook if you prefer file based storage. It can be added and configured in much the same way as the other hooks (with somewhat less options). | ||||
| ```go | ||||
| err := server.AddHook(new(badger.Hook), &badger.Options{ | ||||
|   Path: badgerPath, | ||||
| }) | ||||
| if err != nil { | ||||
|   log.Fatal(err) | ||||
| } | ||||
| ``` | ||||
| For more information on how the badger hook works, or how to use it, see the [examples/persistence/badger/main.go](examples/persistence/badger/main.go) or [hooks/storage/badger](hooks/storage/badger) code. | ||||
|  | ||||
| There is also a BoltDB hook which has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go). | ||||
|  | ||||
|  | ||||
|  | ||||
| ## Developing with Event Hooks | ||||
| Many hooks are available for interacting with the broker and client lifecycle.  | ||||
| The function signatures for all the hooks and `mqtt.Hook` interface can be found in [hooks.go](hooks.go). | ||||
|  | ||||
| > The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks be used to control and modify all incoming and outgoing packets. | ||||
|  | ||||
| | Function | Usage |  | ||||
| | -------------------------- | -- | | ||||
| | OnStarted | Called when the server has successfully started.| | ||||
| | OnStopped | Called when the server has successfully stopped. |  | ||||
| | OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. | | ||||
| | OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. | | ||||
| | OnSysInfoTick | Called when the $SYS topic values are published out. | | ||||
| | OnConnect | Called when a new client connects |  | ||||
| | OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |  | ||||
| | OnDisconnect | Called when a client is disconnected for any reason. |  | ||||
| | OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |  | ||||
| | OnPacketRead | Called when a packet is received from a client. Allows packet modification. |  | ||||
| | OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |  | ||||
| | OnPacketSent | Called when a packet has been sent to a client. |  | ||||
| | OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |  | ||||
| | OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |  | ||||
| | OnSubscribed | Called when a client successfully subscribes to one or more filters. |  | ||||
| | OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.|  | ||||
| | OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |  | ||||
| | OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |  | ||||
| | OnPublish | Called when a client publishes a message. Allows packet modification. |  | ||||
| | OnPublished | Called when a client has published a message to subscribers. |  | ||||
| | OnRetainMessage | Called then a published message is retained. |  | ||||
| | OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |  | ||||
| | OnQosComplete | Called when the Qos flow for a message has been completed. |  | ||||
| | OnQosDropped | Called when an inflight message expires before completion. |  | ||||
| | OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |  | ||||
| | OnWillSent | Called when an LWT message has been issued from a disconnecting client. |  | ||||
| | OnClientExpired | Called when a client session has expired and should be deleted. |  | ||||
| | OnRetainedExpired | Called when a retained message has expired and should be deleted. |  | ||||
| | OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|  | ||||
| | StoredClients |  Returns clients, eg. from a persistent store. |  | ||||
| | StoredSubscriptions |  Returns client subscriptions, eg. from a persistent store. |  | ||||
| | StoredInflightMessages | Returns inflight messages, eg. from a persistent store.  |  | ||||
| | StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |  | ||||
| | StoredSysInfo | Returns stored system info values, eg. from a persistent store. |  | ||||
|  | ||||
| If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`. | ||||
|  | ||||
|  | ||||
| ### Direct Publish | ||||
| To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method. | ||||
|  | ||||
| ```go | ||||
| err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0) | ||||
| ``` | ||||
| > The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec. | ||||
|  | ||||
| ### Packet Injection | ||||
| If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.  | ||||
|  | ||||
| Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements). | ||||
|  | ||||
| ```go | ||||
| cl := server.NewClient(nil, "local", "inline", true) | ||||
| server.InjectPacket(cl, packets.Packet{ | ||||
|   FixedHeader: packets.FixedHeader{ | ||||
|     Type: packets.Publish, | ||||
|   }, | ||||
|   TopicName: "direct/publish", | ||||
|   Payload: []byte("scheduled message"), | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| > MQTT packets still need to be correctly formed, so refer our [the test packets catalogue](packets/tpackets.go) and [MQTTv5 Specification](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) for inspiration. | ||||
|  | ||||
| See the [hooks example](examples/hooks/main.go) to see this feature in action. | ||||
|  | ||||
| ### Testing | ||||
| #### Unit Tests | ||||
| Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go: | ||||
| ``` | ||||
| go run --cover ./... | ||||
| ``` | ||||
|  | ||||
| #### Paho Interoperability Test | ||||
| You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the test with `python3 client_test.py` from the _interoperability_ folder. | ||||
| You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the mqtt v5 and v3 tests with `python3 client_test5.py` from the _interoperability_ folder.  | ||||
|  | ||||
| > Note that there are currently a number of outstanding issues regarding false negatives in the paho suite, and as such, certain compatibility modes are enabled in the `paho/main.go` example. | ||||
|  | ||||
|  | ||||
| #### Performance (messages/second) | ||||
| Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a  13-inch, Early 2015 Macbook Pro (2.7 GHz Intel Core i5). Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better. SEND = Publish throughput, RECV = Subscribe throughput. | ||||
| ## Performance Benchmarks | ||||
| Mochi MQTT performance is comparable with popular brokers such as Mosquitto, EMQX, and others. | ||||
|  | ||||
| > As usual, any performance benchmarks should be taken with a pinch of salt, but are shown to demonstrate typical throughput compared to the other leading MQTT brokers. | ||||
| Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a Apple Macbook Air M2, using `cmd/main.go` default settings. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better. | ||||
|  | ||||
| **Single Client, 10,000 messages** | ||||
| _With only 1 client, there is no variation in throughput so the benchmark is reports the same number for high, low, and median._ | ||||
| > The values presented in the benchmark are not representative of true messages per second throughput. They rely on an unusual calculation by mqtt-stresser, but are usable as they are consistent across all brokers. | ||||
| > Benchmarks are provided as a general performance expectation guideline only. | ||||
|  | ||||
|  | ||||
|  | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=1 -num-messages=10000` | ||||
|  | ||||
| |              | Mochi     | Mosquitto   | EMQX     | VerneMQ   | Mosca   |   | ||||
| | :----------- | --------: | ----------: | -------: | --------: | --------: | ||||
| | SEND Max    | 36505  |   30597  | 27202  | 32782  | 30125   | | ||||
| | SEND Min    |  36505    |  30597  | 27202   |  32782  | 30125  | | ||||
| | SEND Median  | 36505   | 30597   | 27202   |32782    | 30125  | | ||||
| | RECV Max    | 152221  |  59130  | 7879   | 17551   | 9145   | | ||||
| | RECV Min    | 152221  | 59130   | 7879   |  17551    |  9145    | | ||||
| | RECV Median    | 152221  |  59130  | 7879   |  17551   |  9145   | | ||||
|  | ||||
| **10 Clients, 1,000 Messages** | ||||
|  | ||||
|  | ||||
|  | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=1000` | ||||
|  | ||||
| |              | Mochi     | Mosquitto   | EMQX     | VerneMQ   | Mosca   |   | ||||
| | :----------- | --------: | ----------: | -------: | --------: | --------: | ||||
| | SEND Max    |  37193 | 	15775 |	17455 |	34138 |	36575  | | ||||
| | SEND Min    |   6529 |	6446 |	7714 |	8583 |	7383      | | ||||
| | SEND Median  |  15127 |	7813 | 	10305 |	9887 |	8169     | | ||||
| | RECV Max    |  33535	 | 3710	| 3022 |	4534 |	9411    | | ||||
| | RECV Min    |   7484	| 2661	| 1689 |	2021 |	2275     | | ||||
| | RECV Median    |   11427 |  3142 | 1831 |	2468 |	4692      | | ||||
|  | ||||
| **10 Clients, 10,000 Messages** | ||||
|  | ||||
|  | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000` | ||||
| | Broker            | publish fastest | median | slowest | receive fastest | median | slowest |  | ||||
| | --                | --             | --   | --   | --             | --   | --   | | ||||
| | Mochi v2.0.0      | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 | | ||||
| | Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 | | ||||
| | EMQX v5.0.11      | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 | | ||||
|  | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000` | ||||
| | Broker            | publish fastest | median | slowest | receive fastest | median | slowest |  | ||||
| | --                | --             | --   | --   | --             | --   | --   | | ||||
| | Mochi v2.0.0      | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 | | ||||
| | Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 | | ||||
| | EMQX v5.0.11      | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 | | ||||
|  | ||||
| |              | Mochi     | Mosquitto   | EMQX     | VerneMQ   | Mosca   |   | ||||
| | :----------- | --------: | ----------: | -------: | --------: | --------: | ||||
| | SEND Max    |   13153 |	13270 |	12229 |	13025 |	38446  | | ||||
| | SEND Min    |  8728	| 8513	| 8193 | 	6483 |	3889    | | ||||
| | SEND Median  |   9045	| 9532	| 9252 |	8031 |	9210    | | ||||
| | RECV Max    |  20774	| 5052	| 2093 |	2071 | 	43008    | | ||||
| | RECV Min    |   10718	 |3995	| 1531	| 1673	| 18764   | | ||||
| | RECV Median    |  16339 |	4607 |	1620 | 	1907	| 33524  | | ||||
| Million Message Challenge (hit the server with 1 million messages immediately): | ||||
|  | ||||
| **500 Clients, 100 Messages** | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000` | ||||
| | Broker            | publish fastest | median | slowest | receive fastest | median | slowest |  | ||||
| | --                | --             | --   | --   | --             | --   | --   | | ||||
| | Mochi v2.0.0      | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636  | | ||||
| | Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 | | ||||
| | EMQX v5.0.11      | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 | | ||||
|  | ||||
|  | ||||
|  | ||||
| `mqtt-stresser -broker tcp://localhost:1883 -num-clients=500 -num-messages=100` | ||||
|  | ||||
| |              | Mochi     | Mosquitto   | EMQX     | VerneMQ   | Mosca   |   | ||||
| | :----------- | --------: | ----------: | -------: | --------: | --------: | ||||
| | SEND Max    |  70688	| 72686	| 71392 |	75336 |	73192   | | ||||
| | SEND Min    |   1021	| 2577 |	1603 |	8417 |	2344  | | ||||
| | SEND Median  |  49871	| 33076 |	33637 |	35200 |	31312   | | ||||
| | RECV Max    |  116163 |	4215 |	3427 |	5484 |	10100 | | ||||
| | RECV Min    |   1044	| 156 | 	56 | 	83	| 169   | | ||||
| | RECV Median    |     24398 | 208 |	94 |	413 |	474     | | ||||
| > Not sure what's going on with EMQX here, perhaps the docker out-of-the-box settings are not optimal, so take it with a pinch of salt as we know for a fact it's a solid piece of software. | ||||
|  | ||||
| ## Stargazers over time 🥰 | ||||
| [](https://starchart.cc/mochi-co/mqtt) | ||||
| Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-co/mqtt/issues) | ||||
|  | ||||
| ## Contributions | ||||
| Contributions and feedback are both welcomed and encouraged! Open an [issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request. | ||||
| Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request. | ||||
|  | ||||
|  | ||||
|  | ||||
|   | ||||
										
											Binary file not shown.
										
									
								
							| Before Width: | Height: | Size: 39 KiB | 
										
											Binary file not shown.
										
									
								
							| Before Width: | Height: | Size: 40 KiB | 
										
											Binary file not shown.
										
									
								
							| Before Width: | Height: | Size: 38 KiB | 
										
											Binary file not shown.
										
									
								
							| Before Width: | Height: | Size: 37 KiB | 
							
								
								
									
										544
									
								
								clients.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										544
									
								
								clients.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,544 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 J. Blake / mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/rs/xid" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	defaultKeepalive             uint16 = 10 // the default connection keepalive value in seconds | ||||
| 	defaultClientProtocolVersion byte   = 4  // the default mqtt protocol version of connecting clients (if somehow unspecified). | ||||
| ) | ||||
|  | ||||
| // ReadFn is the function signature for the function used for reading and processing new packets. | ||||
| type ReadFn func(*Client, packets.Packet) error | ||||
|  | ||||
| // Clients contains a map of the clients known by the broker. | ||||
| type Clients struct { | ||||
| 	internal map[string]*Client // clients known by the broker, keyed on client id. | ||||
| 	sync.RWMutex | ||||
| } | ||||
|  | ||||
| // NewClients returns an instance of Clients. | ||||
| func NewClients() *Clients { | ||||
| 	return &Clients{ | ||||
| 		internal: make(map[string]*Client), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Add adds a new client to the clients map, keyed on client id. | ||||
| func (cl *Clients) Add(val *Client) { | ||||
| 	cl.Lock() | ||||
| 	defer cl.Unlock() | ||||
| 	cl.internal[val.ID] = val | ||||
| } | ||||
|  | ||||
| // GetAll returns all the clients. | ||||
| func (cl *Clients) GetAll() map[string]*Client { | ||||
| 	cl.RLock() | ||||
| 	defer cl.RUnlock() | ||||
| 	m := map[string]*Client{} | ||||
| 	for k, v := range cl.internal { | ||||
| 		m[k] = v | ||||
| 	} | ||||
| 	return m | ||||
| } | ||||
|  | ||||
| // Get returns the value of a client if it exists. | ||||
| func (cl *Clients) Get(id string) (*Client, bool) { | ||||
| 	cl.RLock() | ||||
| 	defer cl.RUnlock() | ||||
| 	val, ok := cl.internal[id] | ||||
| 	return val, ok | ||||
| } | ||||
|  | ||||
| // Len returns the length of the clients map. | ||||
| func (cl *Clients) Len() int { | ||||
| 	cl.RLock() | ||||
| 	defer cl.RUnlock() | ||||
| 	val := len(cl.internal) | ||||
| 	return val | ||||
| } | ||||
|  | ||||
| // Delete removes a client from the internal map. | ||||
| func (cl *Clients) Delete(id string) { | ||||
| 	cl.Lock() | ||||
| 	defer cl.Unlock() | ||||
| 	delete(cl.internal, id) | ||||
| } | ||||
|  | ||||
| // GetByListener returns clients matching a listener id. | ||||
| func (cl *Clients) GetByListener(id string) []*Client { | ||||
| 	cl.RLock() | ||||
| 	defer cl.RUnlock() | ||||
| 	clients := make([]*Client, 0, cl.Len()) | ||||
| 	for _, client := range cl.internal { | ||||
| 		if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 { | ||||
| 			clients = append(clients, client) | ||||
| 		} | ||||
| 	} | ||||
| 	return clients | ||||
| } | ||||
|  | ||||
| // Client contains information about a client known by the broker. | ||||
| type Client struct { | ||||
| 	Properties   ClientProperties // client properties | ||||
| 	State        ClientState      // the operational state of the client. | ||||
| 	Net          ClientConnection // network connection state of the clinet | ||||
| 	ID           string           // the client id. | ||||
| 	ops          *ops             // ops provides a reference to server ops. | ||||
| 	sync.RWMutex                  // mutex | ||||
| } | ||||
|  | ||||
| // ClientConnection contains the connection transport and metadata for the client. | ||||
| type ClientConnection struct { | ||||
| 	conn     net.Conn          // the net.Conn used to establish the connection | ||||
| 	bconn    *bufio.ReadWriter // a buffered net.Conn for reading packets | ||||
| 	Remote   string            // the remote address of the client | ||||
| 	Listener string            // listener id of the client | ||||
| 	Inline   bool              // client is an inline programmetic client | ||||
| } | ||||
|  | ||||
| // ClientProperties contains the properties which define the client behaviour. | ||||
| type ClientProperties struct { | ||||
| 	Props           packets.Properties | ||||
| 	Will            Will | ||||
| 	Username        []byte | ||||
| 	ProtocolVersion byte | ||||
| 	Clean           bool | ||||
| } | ||||
|  | ||||
| // Will contains the last will and testament details for a client connection. | ||||
| type Will struct { | ||||
| 	Payload           []byte                 // - | ||||
| 	User              []packets.UserProperty // - | ||||
| 	TopicName         string                 // - | ||||
| 	Flag              uint32                 // 0,1 | ||||
| 	WillDelayInterval uint32                 // - | ||||
| 	Qos               byte                   // - | ||||
| 	Retain            bool                   // - | ||||
| } | ||||
|  | ||||
| // State tracks the state of the client. | ||||
| type ClientState struct { | ||||
| 	TopicAliases  TopicAliases   // a map of topic aliases | ||||
| 	stopCause     atomic.Value   // reason for stopping | ||||
| 	Inflight      *Inflight      // a map of in-flight qos messages | ||||
| 	Subscriptions *Subscriptions // a map of the subscription filters a client maintains | ||||
| 	disconnected  int64          // the time the client disconnected in unix time, for calculating expiry | ||||
| 	endOnce       sync.Once      // only end once | ||||
| 	packetID      uint32         // the current highest packetID | ||||
| 	done          uint32         // atomic counter which indicates that the client has closed | ||||
| 	keepalive     uint16         // the number of seconds the connection can wait | ||||
| } | ||||
|  | ||||
| // newClient returns a new instance of Client. This is almost exclusively used by Server | ||||
| // for creating new clients, but it lives here because it's not dependent. | ||||
| func newClient(c net.Conn, o *ops) *Client { | ||||
| 	cl := &Client{ | ||||
| 		State: ClientState{ | ||||
| 			Inflight:      NewInflights(), | ||||
| 			Subscriptions: NewSubscriptions(), | ||||
| 			TopicAliases:  NewTopicAliases(o.capabilities.TopicAliasMaximum), | ||||
| 			keepalive:     defaultKeepalive, | ||||
| 		}, | ||||
| 		Properties: ClientProperties{ | ||||
| 			ProtocolVersion: defaultClientProtocolVersion, // default protocol version | ||||
| 		}, | ||||
| 		ops: o, | ||||
| 	} | ||||
|  | ||||
| 	if c != nil { | ||||
| 		cl.Net = ClientConnection{ | ||||
| 			conn:   c, | ||||
| 			bconn:  bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)), | ||||
| 			Remote: c.RemoteAddr().String(), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	cl.refreshDeadline(cl.State.keepalive) | ||||
|  | ||||
| 	return cl | ||||
| } | ||||
|  | ||||
| // ParseConnect parses the connect parameters and properties for a client. | ||||
| func (cl *Client) ParseConnect(lid string, pk packets.Packet) { | ||||
| 	cl.Net.Listener = lid | ||||
|  | ||||
| 	cl.Properties.ProtocolVersion = pk.ProtocolVersion | ||||
| 	cl.Properties.Username = pk.Connect.Username | ||||
| 	cl.Properties.Clean = pk.Connect.Clean | ||||
| 	cl.Properties.Props = pk.Properties.Copy(false) | ||||
|  | ||||
| 	cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client | ||||
| 	cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))    // client receive max | ||||
|  | ||||
| 	cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum) | ||||
|  | ||||
| 	cl.ID = pk.Connect.ClientIdentifier | ||||
| 	if cl.ID == "" { | ||||
| 		cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7] | ||||
| 		cl.Properties.Props.AssignedClientID = cl.ID | ||||
| 	} | ||||
|  | ||||
| 	cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive | ||||
| 	if pk.Connect.Keepalive > 0 { | ||||
| 		cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22] | ||||
| 	} | ||||
|  | ||||
| 	if pk.Connect.WillFlag { | ||||
| 		cl.Properties.Will = Will{ | ||||
| 			Qos:               pk.Connect.WillQos, | ||||
| 			Retain:            pk.Connect.WillRetain, | ||||
| 			Payload:           pk.Connect.WillPayload, | ||||
| 			TopicName:         pk.Connect.WillTopic, | ||||
| 			WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval, | ||||
| 			User:              pk.Connect.WillProperties.User, | ||||
| 		} | ||||
| 		if pk.Properties.SessionExpiryIntervalFlag && | ||||
| 			pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval { | ||||
| 			cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval | ||||
| 		} | ||||
| 		if pk.Connect.WillFlag { | ||||
| 			cl.Properties.Will.Flag = 1 // atomic for checking | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	cl.refreshDeadline(cl.State.keepalive) | ||||
| } | ||||
|  | ||||
| // refreshDeadline refreshes the read/write deadline for the net.Conn connection. | ||||
| func (cl *Client) refreshDeadline(keepalive uint16) { | ||||
| 	if cl.Net.conn != nil { | ||||
| 		var expiry time.Time // nil time can be used to disable deadline if keepalive = 0 | ||||
| 		if keepalive > 0 { | ||||
| 			expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22] | ||||
| 		} | ||||
|  | ||||
| 		_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22] | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // NextPacketID returns the next available (unused) packet id for the client. | ||||
| // If no unused packet ids are available, an error is returned and the client | ||||
| // should be disconnected. | ||||
| func (cl *Client) NextPacketID() (i uint32, err error) { | ||||
| 	i = atomic.LoadUint32(&cl.State.packetID) | ||||
| 	started := i + 1 | ||||
| 	overflowed := false | ||||
| 	for { | ||||
| 		if i >= 65535 { | ||||
| 			overflowed = true | ||||
| 			i = 1 | ||||
| 		} else { | ||||
| 			i++ | ||||
| 		} | ||||
|  | ||||
| 		if overflowed && i == started { | ||||
| 			return 0, packets.ErrQuotaExceeded | ||||
| 		} | ||||
|  | ||||
| 		if _, ok := cl.State.Inflight.Get(uint16(i)); !ok { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	atomic.StoreUint32(&cl.State.packetID, i) | ||||
| 	return i, nil | ||||
| } | ||||
|  | ||||
| // ResendInflightMessages attempts to resend any pending inflight messages to connected clients. | ||||
| func (cl *Client) ResendInflightMessages(force bool) error { | ||||
| 	if cl.State.Inflight.Len() == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	for _, tk := range cl.State.Inflight.GetAll(false) { | ||||
| 		if tk.FixedHeader.Type == packets.Publish { | ||||
| 			tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3] | ||||
| 		} | ||||
|  | ||||
| 		//	cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends) | ||||
| 		err := cl.WritePacket(tk) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp { | ||||
| 			if ok := cl.State.Inflight.Delete(tk.PacketID); ok { | ||||
| 				cl.ops.hooks.OnQosComplete(cl, tk) | ||||
| 				atomic.AddInt64(&cl.ops.info.Inflight, -1) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session. | ||||
| func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 { | ||||
| 	var deleted int64 | ||||
| 	for _, tk := range cl.State.Inflight.GetAll(false) { | ||||
| 		if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now { | ||||
| 			if ok := cl.State.Inflight.Delete(tk.PacketID); ok { | ||||
| 				cl.ops.hooks.OnQosDropped(cl, tk) | ||||
| 				atomic.AddInt64(&cl.ops.info.Inflight, -1) | ||||
| 				deleted++ | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return deleted | ||||
| } | ||||
|  | ||||
| // Read reads incoming packets from the connected client and transforms them into | ||||
| // packets to be handled by the packetHandler. | ||||
| func (cl *Client) Read(packetHandler ReadFn) error { | ||||
| 	var err error | ||||
|  | ||||
| 	for { | ||||
| 		if atomic.LoadUint32(&cl.State.done) == 1 { | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		cl.refreshDeadline(cl.State.keepalive) | ||||
| 		fh := new(packets.FixedHeader) | ||||
| 		err = cl.ReadFixedHeader(fh) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		pk, err := cl.ReadPacket(fh) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = packetHandler(cl, pk) // Process inbound packet. | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Stop instructs the client to shut down all processing goroutines and disconnect. | ||||
| func (cl *Client) Stop(err error) { | ||||
| 	if atomic.LoadUint32(&cl.State.done) == 1 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	cl.State.endOnce.Do(func() { | ||||
| 		if cl.Net.conn != nil { | ||||
| 			_ = cl.Net.conn.Close() // omit close error | ||||
| 		} | ||||
|  | ||||
| 		if err != nil { | ||||
| 			cl.State.stopCause.Store(err) | ||||
| 		} | ||||
|  | ||||
| 		atomic.StoreUint32(&cl.State.done, 1) | ||||
| 		atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix()) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // StopCause returns the reason the client connection was stopped, if any. | ||||
| func (cl *Client) StopCause() error { | ||||
| 	if cl.State.stopCause.Load() == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return cl.State.stopCause.Load().(error) | ||||
| } | ||||
|  | ||||
| // ReadFixedHeader reads in the values of the next packet's fixed header. | ||||
| func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { | ||||
| 	if cl.Net.bconn == nil { | ||||
| 		return ErrConnectionClosed | ||||
| 	} | ||||
|  | ||||
| 	b, err := cl.Net.bconn.ReadByte() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = fh.Decode(b) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var bu int | ||||
| 	fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if cl.ops.capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.capabilities.MaximumPacketSize { | ||||
| 		return packets.ErrPacketTooLarge // [MQTT-3.2.2-15] | ||||
| 	} | ||||
|  | ||||
| 	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ReadPacket reads the remaining buffer into an MQTT packet. | ||||
| func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) { | ||||
| 	atomic.AddInt64(&cl.ops.info.PacketsReceived, 1) | ||||
|  | ||||
| 	pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding | ||||
| 	pk.FixedHeader = *fh | ||||
| 	p := make([]byte, pk.FixedHeader.Remaining) | ||||
| 	n, err := io.ReadFull(cl.Net.bconn, p) | ||||
| 	if err != nil { | ||||
| 		return pk, err | ||||
| 	} | ||||
|  | ||||
| 	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n)) | ||||
|  | ||||
| 	// Decode the remaining packet values using a fresh copy of the bytes, | ||||
| 	// otherwise the next packet will change the data of this one. | ||||
| 	px := append([]byte{}, p[:]...) | ||||
| 	switch pk.FixedHeader.Type { | ||||
| 	case packets.Connect: | ||||
| 		err = pk.ConnectDecode(px) | ||||
| 	case packets.Disconnect: | ||||
| 		err = pk.DisconnectDecode(px) | ||||
| 	case packets.Connack: | ||||
| 		err = pk.ConnackDecode(px) | ||||
| 	case packets.Publish: | ||||
| 		err = pk.PublishDecode(px) | ||||
| 		if err == nil { | ||||
| 			atomic.AddInt64(&cl.ops.info.MessagesReceived, 1) | ||||
| 		} | ||||
| 	case packets.Puback: | ||||
| 		err = pk.PubackDecode(px) | ||||
| 	case packets.Pubrec: | ||||
| 		err = pk.PubrecDecode(px) | ||||
| 	case packets.Pubrel: | ||||
| 		err = pk.PubrelDecode(px) | ||||
| 	case packets.Pubcomp: | ||||
| 		err = pk.PubcompDecode(px) | ||||
| 	case packets.Subscribe: | ||||
| 		err = pk.SubscribeDecode(px) | ||||
| 	case packets.Suback: | ||||
| 		err = pk.SubackDecode(px) | ||||
| 	case packets.Unsubscribe: | ||||
| 		err = pk.UnsubscribeDecode(px) | ||||
| 	case packets.Unsuback: | ||||
| 		err = pk.UnsubackDecode(px) | ||||
| 	case packets.Pingreq: | ||||
| 	case packets.Pingresp: | ||||
| 	case packets.Auth: | ||||
| 		err = pk.AuthDecode(px) | ||||
| 	default: | ||||
| 		err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type) | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return pk, err | ||||
| 	} | ||||
|  | ||||
| 	pk, err = cl.ops.hooks.OnPacketRead(cl, pk) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // WritePacket encodes and writes a packet to the client. | ||||
| func (cl *Client) WritePacket(pk packets.Packet) error { | ||||
| 	if atomic.LoadUint32(&cl.State.done) == 1 { | ||||
| 		return ErrConnectionClosed | ||||
| 	} | ||||
|  | ||||
| 	if cl.Net.conn == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	defer cl.refreshDeadline(cl.State.keepalive) | ||||
| 	if pk.Expiry > 0 { | ||||
| 		pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6] | ||||
| 	} | ||||
|  | ||||
| 	pk.ProtocolVersion = cl.Properties.ProtocolVersion | ||||
| 	if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests | ||||
| 		pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize | ||||
| 	} | ||||
|  | ||||
| 	if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 { | ||||
| 		pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set | ||||
| 	} | ||||
|  | ||||
| 	if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo { | ||||
| 		pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode | ||||
| 	} | ||||
|  | ||||
| 	pk = cl.ops.hooks.OnPacketEncode(cl, pk) | ||||
|  | ||||
| 	var err error | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	switch pk.FixedHeader.Type { | ||||
| 	case packets.Connect: | ||||
| 		err = pk.ConnectEncode(buf) | ||||
| 	case packets.Connack: | ||||
| 		err = pk.ConnackEncode(buf) | ||||
| 	case packets.Publish: | ||||
| 		err = pk.PublishEncode(buf) | ||||
| 	case packets.Puback: | ||||
| 		err = pk.PubackEncode(buf) | ||||
| 	case packets.Pubrec: | ||||
| 		err = pk.PubrecEncode(buf) | ||||
| 	case packets.Pubrel: | ||||
| 		err = pk.PubrelEncode(buf) | ||||
| 	case packets.Pubcomp: | ||||
| 		err = pk.PubcompEncode(buf) | ||||
| 	case packets.Subscribe: | ||||
| 		err = pk.SubscribeEncode(buf) | ||||
| 	case packets.Suback: | ||||
| 		err = pk.SubackEncode(buf) | ||||
| 	case packets.Unsubscribe: | ||||
| 		err = pk.UnsubscribeEncode(buf) | ||||
| 	case packets.Unsuback: | ||||
| 		err = pk.UnsubackEncode(buf) | ||||
| 	case packets.Pingreq: | ||||
| 		err = pk.PingreqEncode(buf) | ||||
| 	case packets.Pingresp: | ||||
| 		err = pk.PingrespEncode(buf) | ||||
| 	case packets.Disconnect: | ||||
| 		err = pk.DisconnectEncode(buf) | ||||
| 	case packets.Auth: | ||||
| 		err = pk.AuthEncode(buf) | ||||
| 	default: | ||||
| 		err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize { | ||||
| 		return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25] | ||||
| 	} | ||||
|  | ||||
| 	nb := net.Buffers{buf.Bytes()} | ||||
| 	n, err := nb.WriteTo(cl.Net.conn) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	atomic.AddInt64(&cl.ops.info.BytesSent, n) | ||||
| 	atomic.AddInt64(&cl.ops.info.PacketsSent, 1) | ||||
| 	if pk.FixedHeader.Type == packets.Publish { | ||||
| 		atomic.AddInt64(&cl.ops.info.MessagesSent, 1) | ||||
| 	} | ||||
|  | ||||
| 	cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes()) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										723
									
								
								clients_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										723
									
								
								clients_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,723 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| const pkInfo = "packet type %v, %s" | ||||
|  | ||||
| var errClientStop = errors.New("test stop") | ||||
|  | ||||
| func newTestClient() (cl *Client, r net.Conn, w net.Conn) { | ||||
| 	r, w = net.Pipe() | ||||
|  | ||||
| 	cl = newClient(w, &ops{ | ||||
| 		info:  new(system.Info), | ||||
| 		hooks: new(Hooks), | ||||
| 		log:   &logger, | ||||
| 		capabilities: &Capabilities{ | ||||
| 			ReceiveMaximum:    10, | ||||
| 			TopicAliasMaximum: 10000, | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| 	cl.ID = "mochi" | ||||
| 	cl.State.Inflight.maximumSendQuota = 5 | ||||
| 	cl.State.Inflight.sendQuota = 5 | ||||
| 	cl.State.Inflight.maximumReceiveQuota = 10 | ||||
| 	cl.State.Inflight.receiveQuota = 10 | ||||
| 	cl.Properties.Props.TopicAliasMaximum = 0 | ||||
| 	cl.Properties.Props.RequestResponseInfo = 0x1 | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func TestNewInflights(t *testing.T) { | ||||
| 	require.NotNil(t, NewInflights().internal) | ||||
| } | ||||
|  | ||||
| func TestNewClients(t *testing.T) { | ||||
| 	cl := NewClients() | ||||
| 	require.NotNil(t, cl.internal) | ||||
| } | ||||
|  | ||||
| func TestClientsAdd(t *testing.T) { | ||||
| 	cl := NewClients() | ||||
| 	cl.Add(&Client{ID: "t1"}) | ||||
| 	require.Contains(t, cl.internal, "t1") | ||||
| } | ||||
|  | ||||
| func TestClientsGet(t *testing.T) { | ||||
| 	cl := NewClients() | ||||
| 	cl.Add(&Client{ID: "t1"}) | ||||
| 	cl.Add(&Client{ID: "t2"}) | ||||
| 	require.Contains(t, cl.internal, "t1") | ||||
| 	require.Contains(t, cl.internal, "t2") | ||||
|  | ||||
| 	client, ok := cl.Get("t1") | ||||
| 	require.Equal(t, true, ok) | ||||
| 	require.Equal(t, "t1", client.ID) | ||||
| } | ||||
|  | ||||
| func TestClientsGetAll(t *testing.T) { | ||||
| 	cl := NewClients() | ||||
| 	cl.Add(&Client{ID: "t1"}) | ||||
| 	cl.Add(&Client{ID: "t2"}) | ||||
| 	cl.Add(&Client{ID: "t3"}) | ||||
| 	require.Contains(t, cl.internal, "t1") | ||||
| 	require.Contains(t, cl.internal, "t2") | ||||
| 	require.Contains(t, cl.internal, "t3") | ||||
|  | ||||
| 	clients := cl.GetAll() | ||||
| 	require.Len(t, clients, 3) | ||||
| } | ||||
|  | ||||
| func TestClientsLen(t *testing.T) { | ||||
| 	cl := NewClients() | ||||
| 	cl.Add(&Client{ID: "t1"}) | ||||
| 	cl.Add(&Client{ID: "t2"}) | ||||
| 	require.Contains(t, cl.internal, "t1") | ||||
| 	require.Contains(t, cl.internal, "t2") | ||||
| 	require.Equal(t, 2, cl.Len()) | ||||
| } | ||||
|  | ||||
| func TestClientsDelete(t *testing.T) { | ||||
| 	cl := NewClients() | ||||
| 	cl.Add(&Client{ID: "t1"}) | ||||
| 	require.Contains(t, cl.internal, "t1") | ||||
|  | ||||
| 	cl.Delete("t1") | ||||
| 	_, ok := cl.Get("t1") | ||||
| 	require.Equal(t, false, ok) | ||||
| 	require.Nil(t, cl.internal["t1"]) | ||||
| } | ||||
|  | ||||
| func TestClientsGetByListener(t *testing.T) { | ||||
| 	cl := NewClients() | ||||
| 	cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}}) | ||||
| 	cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}}) | ||||
| 	require.Contains(t, cl.internal, "t1") | ||||
| 	require.Contains(t, cl.internal, "t2") | ||||
|  | ||||
| 	clients := cl.GetByListener("tcp1") | ||||
| 	require.NotEmpty(t, clients) | ||||
| 	require.Equal(t, 1, len(clients)) | ||||
| 	require.Equal(t, "tcp1", clients[0].Net.Listener) | ||||
| } | ||||
|  | ||||
| func TestNewClient(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	require.NotNil(t, cl) | ||||
| 	require.NotNil(t, cl.State.Inflight.internal) | ||||
| 	require.NotNil(t, cl.State.Subscriptions) | ||||
| 	require.NotNil(t, cl.State.TopicAliases) | ||||
| 	require.Equal(t, defaultKeepalive, cl.State.keepalive) | ||||
| 	require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion) | ||||
| 	require.NotNil(t, cl.Net.conn) | ||||
| 	require.NotNil(t, cl.Net.bconn) | ||||
| 	require.False(t, cl.Net.Inline) | ||||
| } | ||||
|  | ||||
| func TestClientParseConnect(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	pk := packets.Packet{ | ||||
| 		ProtocolVersion: 4, | ||||
| 		Connect: packets.ConnectParams{ | ||||
| 			ProtocolName:     []byte{'M', 'Q', 'T', 'T'}, | ||||
| 			Clean:            true, | ||||
| 			Keepalive:        60, | ||||
| 			ClientIdentifier: "mochi", | ||||
| 			WillFlag:         true, | ||||
| 			WillTopic:        "lwt", | ||||
| 			WillPayload:      []byte("lol gg"), | ||||
| 			WillQos:          1, | ||||
| 			WillRetain:       false, | ||||
| 		}, | ||||
| 		Properties: packets.Properties{ | ||||
| 			ReceiveMaximum: uint16(5), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	cl.ParseConnect("tcp1", pk) | ||||
| 	require.Equal(t, pk.Connect.ClientIdentifier, cl.ID) | ||||
| 	require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive) | ||||
| 	require.Equal(t, pk.Connect.Clean, cl.Properties.Clean) | ||||
| 	require.Equal(t, pk.Connect.ClientIdentifier, cl.ID) | ||||
| 	require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName) | ||||
| 	require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload) | ||||
| 	require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos) | ||||
| 	require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain) | ||||
| 	require.Equal(t, uint32(1), cl.Properties.Will.Flag) | ||||
| 	require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota) | ||||
| 	require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota) | ||||
| 	require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota) | ||||
| 	require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota) | ||||
| } | ||||
|  | ||||
| func TestClientParseConnectOverrideWillDelay(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	pk := packets.Packet{ | ||||
| 		ProtocolVersion: 4, | ||||
| 		Connect: packets.ConnectParams{ | ||||
| 			ProtocolName:     []byte{'M', 'Q', 'T', 'T'}, | ||||
| 			Clean:            true, | ||||
| 			Keepalive:        60, | ||||
| 			ClientIdentifier: "mochi", | ||||
| 			WillFlag:         true, | ||||
| 			WillProperties: packets.Properties{ | ||||
| 				WillDelayInterval: 200, | ||||
| 			}, | ||||
| 		}, | ||||
| 		Properties: packets.Properties{ | ||||
| 			SessionExpiryInterval:     100, | ||||
| 			SessionExpiryIntervalFlag: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	cl.ParseConnect("tcp1", pk) | ||||
| 	require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval) | ||||
| } | ||||
|  | ||||
| func TestClientParseConnectNoID(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.ParseConnect("tcp1", packets.Packet{}) | ||||
| 	require.NotEmpty(t, cl.ID) | ||||
| } | ||||
|  | ||||
| func TestClientNextPacketID(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	i, err := cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(1), i) | ||||
|  | ||||
| 	i, err = cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(2), i) | ||||
| } | ||||
|  | ||||
| func TestClientNextPacketIDInUse(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	// skip over 2 | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | ||||
|  | ||||
| 	i, err := cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(1), i) | ||||
|  | ||||
| 	i, err = cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(3), i) | ||||
|  | ||||
| 	// Skip over overflow | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 65535}) | ||||
| 	atomic.StoreUint32(&cl.State.packetID, 65534) | ||||
|  | ||||
| 	i, err = cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(1), i) | ||||
| } | ||||
|  | ||||
| func TestClientNextPacketIDExhausted(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	for i := 0; i <= 65535; i++ { | ||||
| 		cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) | ||||
| 	} | ||||
|  | ||||
| 	i, err := cl.NextPacketID() | ||||
| 	require.Equal(t, uint32(0), i) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, packets.ErrQuotaExceeded) | ||||
| } | ||||
|  | ||||
| func TestClientNextPacketIDOverflow(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	cl.State.packetID = uint32(65534) | ||||
|  | ||||
| 	i, err := cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(65535), i) | ||||
|  | ||||
| 	i, err = cl.NextPacketID() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint32(1), i) | ||||
| } | ||||
|  | ||||
| func TestClientClearInflights(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	n := time.Now().Unix() | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n}) | ||||
| 	require.Equal(t, 5, cl.State.Inflight.Len()) | ||||
|  | ||||
| 	cl.ClearInflights(n, 4) | ||||
| 	require.Equal(t, 2, cl.State.Inflight.Len()) | ||||
| } | ||||
|  | ||||
| func TestClientResendInflightMessages(t *testing.T) { | ||||
| 	pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback) | ||||
| 	cl, r, w := newTestClient() | ||||
|  | ||||
| 	cl.State.Inflight.Set(*pk1.Packet) | ||||
| 	require.Equal(t, 1, cl.State.Inflight.Len()) | ||||
|  | ||||
| 	go func() { | ||||
| 		err := cl.ResendInflightMessages(true) | ||||
| 		require.NoError(t, err) | ||||
| 		time.Sleep(time.Millisecond) | ||||
| 		w.Close() | ||||
| 	}() | ||||
|  | ||||
| 	buf, err := io.ReadAll(r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, 0, cl.State.Inflight.Len()) | ||||
| 	require.Equal(t, pk1.RawBytes, buf) | ||||
| } | ||||
|  | ||||
| func TestClientResendInflightMessagesWriteFailure(t *testing.T) { | ||||
| 	pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup) | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	r.Close() | ||||
|  | ||||
| 	cl.State.Inflight.Set(*pk1.Packet) | ||||
| 	require.Equal(t, 1, cl.State.Inflight.Len()) | ||||
| 	err := cl.ResendInflightMessages(true) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, io.ErrClosedPipe) | ||||
| 	require.Equal(t, 1, cl.State.Inflight.Len()) | ||||
| } | ||||
|  | ||||
| func TestClientResendInflightMessagesNoMessages(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	err := cl.ResendInflightMessages(true) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientRefreshDeadline(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.refreshDeadline(10) | ||||
| 	require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline? | ||||
| } | ||||
|  | ||||
| func TestClientReadFixedHeader(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
|  | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	go func() { | ||||
| 		r.Write([]byte{packets.Connect << 4, 0x00}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	fh := new(packets.FixedHeader) | ||||
| 	err := cl.ReadFixedHeader(fh) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived)) | ||||
| } | ||||
|  | ||||
| func TestClientReadFixedHeaderDecodeError(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
|  | ||||
| 	go func() { | ||||
| 		r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	fh := new(packets.FixedHeader) | ||||
| 	err := cl.ReadFixedHeader(fh) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientReadFixedHeaderPacketOversized(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	cl.ops.capabilities.MaximumPacketSize = 2 | ||||
| 	defer cl.Stop(errClientStop) | ||||
|  | ||||
| 	go func() { | ||||
| 		r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	fh := new(packets.FixedHeader) | ||||
| 	err := cl.ReadFixedHeader(fh) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, packets.ErrPacketTooLarge) | ||||
| } | ||||
|  | ||||
| func TestClientReadFixedHeaderReadEOF(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
|  | ||||
| 	go func() { | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	fh := new(packets.FixedHeader) | ||||
| 	err := cl.ReadFixedHeader(fh) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, io.EOF, err) | ||||
| } | ||||
|  | ||||
| func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
|  | ||||
| 	go func() { | ||||
| 		r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	fh := new(packets.FixedHeader) | ||||
| 	err := cl.ReadFixedHeader(fh) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientReadOK(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	go func() { | ||||
| 		r.Write([]byte{ | ||||
| 			packets.Publish << 4, 18, // Fixed header | ||||
| 			0, 5, // Topic Name - LSB+MSB | ||||
| 			'a', '/', 'b', '/', 'c', // Topic Name | ||||
| 			'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload, | ||||
| 			packets.Publish << 4, 11, // Fixed header | ||||
| 			0, 5, // Topic Name - LSB+MSB | ||||
| 			'd', '/', 'e', '/', 'f', // Topic Name | ||||
| 			'y', 'e', 'a', 'h', // Payload | ||||
| 		}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	var pks []packets.Packet | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		o <- cl.Read(func(cl *Client, pk packets.Packet) error { | ||||
| 			pks = append(pks, pk) | ||||
| 			return nil | ||||
| 		}) | ||||
| 	}() | ||||
|  | ||||
| 	err := <-o | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, io.EOF) | ||||
| 	require.Equal(t, 2, len(pks)) | ||||
| 	require.Equal(t, []packets.Packet{ | ||||
| 		{ | ||||
| 			ProtocolVersion: cl.Properties.ProtocolVersion, | ||||
| 			FixedHeader: packets.FixedHeader{ | ||||
| 				Type:      packets.Publish, | ||||
| 				Remaining: 18, | ||||
| 			}, | ||||
| 			TopicName: "a/b/c", | ||||
| 			Payload:   []byte("hello mochi"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			ProtocolVersion: cl.Properties.ProtocolVersion, | ||||
| 			FixedHeader: packets.FixedHeader{ | ||||
| 				Type:      packets.Publish, | ||||
| 				Remaining: 11, | ||||
| 			}, | ||||
| 			TopicName: "d/e/f", | ||||
| 			Payload:   []byte("yeah"), | ||||
| 		}, | ||||
| 	}, pks) | ||||
|  | ||||
| 	require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived)) | ||||
| } | ||||
|  | ||||
| func TestClientReadDone(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	cl.State.done = 1 | ||||
|  | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		o <- cl.Read(func(cl *Client, pk packets.Packet) error { | ||||
| 			return nil | ||||
| 		}) | ||||
| 	}() | ||||
|  | ||||
| 	require.NoError(t, <-o) | ||||
| } | ||||
|  | ||||
| func TestClientStop(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.Stop(nil) | ||||
| 	require.Equal(t, nil, cl.State.stopCause.Load()) | ||||
| 	require.Equal(t, time.Now().Unix(), cl.State.disconnected) | ||||
| 	require.Equal(t, uint32(1), cl.State.done) | ||||
| 	require.Equal(t, nil, cl.StopCause()) | ||||
| } | ||||
|  | ||||
| func TestClientReadFixedHeaderError(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	go func() { | ||||
| 		r.Write([]byte{ | ||||
| 			packets.Publish << 4, 11, // Fixed header | ||||
| 		}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	cl.Net.bconn = nil | ||||
| 	fh := new(packets.FixedHeader) | ||||
| 	err := cl.ReadFixedHeader(fh) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, ErrConnectionClosed, err) | ||||
| } | ||||
|  | ||||
| func TestClientReadReadHandlerErr(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	go func() { | ||||
| 		r.Write([]byte{ | ||||
| 			packets.Publish << 4, 11, // Fixed header | ||||
| 			0, 5, // Topic Name - LSB+MSB | ||||
| 			'd', '/', 'e', '/', 'f', // Topic Name | ||||
| 			'y', 'e', 'a', 'h', // Payload | ||||
| 		}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	err := cl.Read(func(cl *Client, pk packets.Packet) error { | ||||
| 		return errors.New("test") | ||||
| 	}) | ||||
|  | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientReadReadPacketOK(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	go func() { | ||||
| 		r.Write([]byte{ | ||||
| 			packets.Publish << 4, 11, // Fixed header | ||||
| 			0, 5, | ||||
| 			'd', '/', 'e', '/', 'f', | ||||
| 			'y', 'e', 'a', 'h', | ||||
| 		}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	fh := new(packets.FixedHeader) | ||||
| 	err := cl.ReadFixedHeader(fh) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	pk, err := cl.ReadPacket(fh) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, pk) | ||||
|  | ||||
| 	require.Equal(t, packets.Packet{ | ||||
| 		ProtocolVersion: cl.Properties.ProtocolVersion, | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Type:      packets.Publish, | ||||
| 			Remaining: 11, | ||||
| 		}, | ||||
| 		TopicName: "d/e/f", | ||||
| 		Payload:   []byte("yeah"), | ||||
| 	}, pk) | ||||
| } | ||||
|  | ||||
| func TestClientReadPacket(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
|  | ||||
| 	for _, tx := range pkTable { | ||||
| 		tt := tx // avoid data race | ||||
| 		t.Run(tt.Desc, func(t *testing.T) { | ||||
| 			atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0) | ||||
| 			go func() { | ||||
| 				r.Write(tt.RawBytes) | ||||
| 			}() | ||||
|  | ||||
| 			fh := new(packets.FixedHeader) | ||||
| 			err := cl.ReadFixedHeader(fh) | ||||
| 			require.NoError(t, err) | ||||
|  | ||||
| 			if tt.Packet.ProtocolVersion == 5 { | ||||
| 				cl.Properties.ProtocolVersion = 5 | ||||
| 			} else { | ||||
| 				cl.Properties.ProtocolVersion = 0 | ||||
| 			} | ||||
|  | ||||
| 			pk, err := cl.ReadPacket(fh) | ||||
| 			require.NoError(t, err, pkInfo, tt.Case, tt.Desc) | ||||
| 			require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc) | ||||
| 			require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc) | ||||
|  | ||||
| 			if tt.Packet.FixedHeader.Type == packets.Publish { | ||||
| 				require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestClientReadPacketInvalidTypeError(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.Net.conn.Close() | ||||
| 	_, err := cl.ReadPacket(&packets.FixedHeader{}) | ||||
| 	require.Error(t, err) | ||||
| 	require.Contains(t, err.Error(), "invalid packet type") | ||||
| } | ||||
|  | ||||
| func TestClientWritePacket(t *testing.T) { | ||||
| 	for _, tt := range pkTable { | ||||
| 		cl, r, _ := newTestClient() | ||||
| 		defer cl.Stop(errClientStop) | ||||
|  | ||||
| 		cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion | ||||
|  | ||||
| 		o := make(chan []byte) | ||||
| 		go func() { | ||||
| 			buf, err := io.ReadAll(r) | ||||
| 			require.NoError(t, err) | ||||
| 			o <- buf | ||||
| 		}() | ||||
|  | ||||
| 		err := cl.WritePacket(*tt.Packet) | ||||
| 		require.NoError(t, err, pkInfo, tt.Case, tt.Desc) | ||||
|  | ||||
| 		time.Sleep(2 * time.Millisecond) | ||||
| 		cl.Net.conn.Close() | ||||
|  | ||||
| 		require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc) | ||||
|  | ||||
| 		cl.Stop(errClientStop) | ||||
| 		time.Sleep(time.Millisecond * 1) | ||||
|  | ||||
| 		// The stop cause is either the test error, EOF, or a | ||||
| 		// closed pipe, depending on which goroutine runs first. | ||||
| 		err = cl.StopCause() | ||||
| 		require.True(t, | ||||
| 			errors.Is(err, errClientStop) || | ||||
| 				errors.Is(err, io.EOF) || | ||||
| 				errors.Is(err, io.ErrClosedPipe)) | ||||
|  | ||||
| 		require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent)) | ||||
| 		require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent)) | ||||
| 		if tt.Packet.FixedHeader.Type == packets.Publish { | ||||
| 			require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestWriteClientOversizePacket(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.Properties.Props.MaximumPacketSize = 2 | ||||
| 	pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet | ||||
| 	err := cl.WritePacket(pk) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, packets.ErrPacketTooLarge, err) | ||||
| } | ||||
|  | ||||
| func TestClientReadPacketReadingError(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	go func() { | ||||
| 		r.Write([]byte{ | ||||
| 			0, 11, // Fixed header | ||||
| 			0, 5, | ||||
| 			'd', '/', 'e', '/', 'f', | ||||
| 			'y', 'e', 'a', 'h', | ||||
| 		}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	_, err := cl.ReadPacket(&packets.FixedHeader{ | ||||
| 		Type:      0, | ||||
| 		Remaining: 11, | ||||
| 	}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientReadPacketReadUnknown(t *testing.T) { | ||||
| 	cl, r, _ := newTestClient() | ||||
| 	defer cl.Stop(errClientStop) | ||||
| 	go func() { | ||||
| 		r.Write([]byte{ | ||||
| 			0, 11, // Fixed header | ||||
| 			0, 5, | ||||
| 			'd', '/', 'e', '/', 'f', | ||||
| 			'y', 'e', 'a', 'h', | ||||
| 		}) | ||||
| 		r.Close() | ||||
| 	}() | ||||
|  | ||||
| 	_, err := cl.ReadPacket(&packets.FixedHeader{ | ||||
| 		Remaining: 1, | ||||
| 	}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientWritePacketWriteNoConn(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.Stop(errClientStop) | ||||
|  | ||||
| 	err := cl.WritePacket(*pkTable[1].Packet) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, ErrConnectionClosed, err) | ||||
| } | ||||
|  | ||||
| func TestClientWritePacketWriteError(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.Net.conn.Close() | ||||
|  | ||||
| 	err := cl.WritePacket(*pkTable[1].Packet) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientWritePacketInvalidPacket(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	err := cl.WritePacket(packets.Packet{}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	pkTable = []packets.TPacketCase{ | ||||
| 		packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311), | ||||
| 		packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5), | ||||
| 		packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession), | ||||
| 		packets.TPacketData[packets.Publish].Get(packets.TPublishBasic), | ||||
| 		packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5), | ||||
| 		packets.TPacketData[packets.Puback].Get(packets.TPuback), | ||||
| 		packets.TPacketData[packets.Pubrec].Get(packets.TPubrec), | ||||
| 		packets.TPacketData[packets.Pubrel].Get(packets.TPubrel), | ||||
| 		packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp), | ||||
| 		packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe), | ||||
| 		packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5), | ||||
| 		packets.TPacketData[packets.Suback].Get(packets.TSuback), | ||||
| 		packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5), | ||||
| 		packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe), | ||||
| 		packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5), | ||||
| 		packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback), | ||||
| 		packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5), | ||||
| 		packets.TPacketData[packets.Pingreq].Get(packets.TPingreq), | ||||
| 		packets.TPacketData[packets.Pingresp].Get(packets.TPingresp), | ||||
| 		packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect), | ||||
| 		packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5), | ||||
| 		packets.TPacketData[packets.Auth].Get(packets.TAuth), | ||||
| 	} | ||||
| ) | ||||
							
								
								
									
										45
									
								
								cmd/main.go
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								cmd/main.go
									
									
									
									
									
								
							| @@ -1,17 +1,19 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/logrusorgru/aurora" | ||||
|  | ||||
| 	mqtt "github.com/mochi-co/mqtt/server" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners" | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| @@ -28,37 +30,36 @@ func main() { | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.Magenta("Mochi MQTT Broker initializing...")) | ||||
| 	fmt.Println(aurora.Cyan("TCP"), *tcpAddr) | ||||
| 	fmt.Println(aurora.Cyan("Websocket"), *wsAddr) | ||||
| 	fmt.Println(aurora.Cyan("$SYS Dashboard"), *infoAddr) | ||||
| 	server := mqtt.New(nil) | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
| 	server := mqtt.New() | ||||
| 	tcp := listeners.NewTCP("t1", *tcpAddr) | ||||
| 	err := server.AddListener(tcp, nil) | ||||
| 	tcp := listeners.NewTCP("t1", *tcpAddr, nil) | ||||
| 	err := server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	ws := listeners.NewWebsocket("ws1", *wsAddr) | ||||
| 	err = server.AddListener(ws, nil) | ||||
| 	ws := listeners.NewWebsocket("ws1", *wsAddr, nil) | ||||
| 	err = server.AddListener(ws) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	stats := listeners.NewHTTPStats("stats", *infoAddr) | ||||
| 	err = server.AddListener(stats, nil) | ||||
| 	stats := listeners.NewHTTPStats("stats", *infoAddr, nil, server.Info) | ||||
| 	err = server.AddListener(stats) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go server.Serve() | ||||
| 	fmt.Println(aurora.BgMagenta("  Started!  ")) | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	fmt.Println(aurora.BgRed("  Caught Signal  ")) | ||||
|  | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	fmt.Println(aurora.BgGreen("  Finished  ")) | ||||
|  | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
|   | ||||
							
								
								
									
										83
									
								
								examples/auth/basic/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								examples/auth/basic/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	authRules := &auth.Ledger{ | ||||
| 		Auth: auth.AuthRules{ // Auth disallows all by default | ||||
| 			{Username: "peach", Password: "password1", Allow: true}, | ||||
| 			{Username: "melon", Password: "password2", Allow: true}, | ||||
| 			{Remote: "127.0.0.1:*", Allow: true}, | ||||
| 			{Remote: "localhost:*", Allow: true}, | ||||
| 		}, | ||||
| 		ACL: auth.ACLRules{ // ACL allows all by default | ||||
| 			{Remote: "127.0.0.1:*"}, // local superuser allow all | ||||
| 			{ | ||||
| 				// user melon can read and write to their own topic | ||||
| 				Username: "melon", Filters: auth.Filters{ | ||||
| 					"melon/#":   auth.ReadWrite, | ||||
| 					"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others | ||||
| 				}, | ||||
| 			}, | ||||
| 			{ | ||||
| 				// Otherwise, no clients have publishing permissions | ||||
| 				Filters: auth.Filters{ | ||||
| 					"#":         auth.ReadOnly, | ||||
| 					"updates/#": auth.Deny, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// you may also find this useful... | ||||
| 	// d, _ := authRules.ToYAML() | ||||
| 	// d, _ := authRules.ToJSON() | ||||
| 	// fmt.Println(string(d)) | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	err := server.AddHook(new(auth.Hook), &auth.Options{ | ||||
| 		Ledger: authRules, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err = server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
							
								
								
									
										40
									
								
								examples/auth/encoded/auth.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								examples/auth/encoded/auth.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
| { | ||||
|   "auth": [ | ||||
|     { | ||||
|       "username": "peach", | ||||
|       "password": "password1", | ||||
|       "allow": true | ||||
|     }, | ||||
|     { | ||||
|       "username": "melon", | ||||
|       "password": "password2", | ||||
|       "allow": true | ||||
|     }, | ||||
|     { | ||||
|       "remote": "127.0.0.1:*", | ||||
|       "allow": false | ||||
|     }, | ||||
|     { | ||||
|       "remote": "localhost:*", | ||||
|       "allow": false | ||||
|     } | ||||
|   ], | ||||
|   "acl": [ | ||||
|     { | ||||
|       "remote": "127.0.0.1:*" | ||||
|     }, | ||||
|     { | ||||
|       "username": "melon", | ||||
|       "filters": { | ||||
|         "melon/#": 3, | ||||
|         "updates/#": 2 | ||||
|       } | ||||
|     }, | ||||
|     { | ||||
|       "filters": { | ||||
|         "#": 1, | ||||
|         "updates/#": 0 | ||||
|       } | ||||
|     } | ||||
|   ] | ||||
| } | ||||
							
								
								
									
										21
									
								
								examples/auth/encoded/auth.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								examples/auth/encoded/auth.yaml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| auth: | ||||
|     - username: peach | ||||
|       password: password1 | ||||
|       allow: true | ||||
|     - username: melon | ||||
|       password: password2 | ||||
|       allow: true | ||||
|     # - remote: 127.0.0.1:* | ||||
|     #   allow: true | ||||
|     # - remote: localhost:* | ||||
|     #   allow: true | ||||
| acl: | ||||
| # 0 = deny, 1 = read only, 2 = write only, 3 = read and write | ||||
|     - remote: 127.0.0.1:* | ||||
|     - username: melon | ||||
|       filters: | ||||
|         melon/#: 3 | ||||
|         updates/#: 2 | ||||
|     - filters: | ||||
|         '#': 1 | ||||
|         updates/#: 0 | ||||
							
								
								
									
										65
									
								
								examples/auth/encoded/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								examples/auth/encoded/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,65 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"flag" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	// You can also run from top-level server.go folder: | ||||
| 	// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.yaml | ||||
| 	// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.json | ||||
| 	path := flag.String("path", "auth.yaml", "path to data auth file") | ||||
| 	flag.Parse() | ||||
|  | ||||
| 	// Get ledger from yaml file | ||||
| 	data, err := os.ReadFile(*path) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	err = server.AddHook(new(auth.Hook), &auth.Options{ | ||||
| 		Data: data, // build ledger from byte slice, yaml or json | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err = server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
| @@ -1,49 +0,0 @@ | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/logrusorgru/aurora" | ||||
|  | ||||
| 	mqtt "github.com/mochi-co/mqtt/server" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP")) | ||||
|  | ||||
| 	server := mqtt.New() | ||||
|  | ||||
| 	stats := listeners.NewHTTPStats("stats", ":8080") | ||||
| 	err := server.AddListener(stats, nil) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
| 	fmt.Println(aurora.BgMagenta("  Started!  ")) | ||||
|  | ||||
| 	<-done | ||||
| 	fmt.Println(aurora.BgRed("  Caught Signal  ")) | ||||
|  | ||||
| 	server.Close() | ||||
| 	fmt.Println(aurora.BgGreen("  Finished  ")) | ||||
|  | ||||
| } | ||||
							
								
								
									
										62
									
								
								examples/debug/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								examples/debug/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/debug" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	l := server.Log.Level(zerolog.DebugLevel) | ||||
| 	server.Log = &l | ||||
|  | ||||
| 	err := server.AddHook(new(auth.AllowHook), nil) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	err = server.AddHook(new(debug.Hook), &debug.Options{ | ||||
| 		ShowPacketData: true, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err = server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
| @@ -1,77 +0,0 @@ | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/logrusorgru/aurora" | ||||
|  | ||||
| 	mqtt "github.com/mochi-co/mqtt/server" | ||||
| 	"github.com/mochi-co/mqtt/server/events" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP")) | ||||
|  | ||||
| 	server := mqtt.New() | ||||
| 	tcp := listeners.NewTCP("t1", ":1883") | ||||
| 	err := server.AddListener(tcp, &listeners.Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Start the server | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// Add OnMessage Event Hook | ||||
| 	server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) { | ||||
| 		pkx = pk | ||||
| 		if string(pk.Payload) == "hello" { | ||||
| 			pkx.Payload = []byte("hello world") | ||||
| 			fmt.Printf("< OnMessage modified message from client %s: %s\n", cl.ID, string(pkx.Payload)) | ||||
| 		} else { | ||||
| 			fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload)) | ||||
| 		} | ||||
|  | ||||
| 		return pkx, nil | ||||
| 	} | ||||
|  | ||||
| 	// Demonstration of directly publishing messages to a topic via the | ||||
| 	// `server.Publish` method. Subscribe to `direct/publish` using your | ||||
| 	// MQTT client to see the messages. | ||||
| 	go func() { | ||||
| 		for range time.Tick(time.Second * 10) { | ||||
| 			server.Publish("direct/publish", []byte("scheduled message"), false) | ||||
| 			fmt.Println("> issued direct message to direct/publish") | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.BgMagenta("  Started!  ")) | ||||
|  | ||||
| 	<-done | ||||
| 	fmt.Println(aurora.BgRed("  Caught Signal  ")) | ||||
|  | ||||
| 	server.Close() | ||||
| 	fmt.Println(aurora.BgGreen("  Finished  ")) | ||||
| } | ||||
							
								
								
									
										143
									
								
								examples/hooks/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								examples/hooks/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err := server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	err = server.AddHook(new(ExampleHook), map[string]any{}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Start the server | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// Demonstration of directly publishing messages to a topic via the | ||||
| 	// `server.Publish` method. Subscribe to `direct/publish` using your | ||||
| 	// MQTT client to see the messages. | ||||
| 	go func() { | ||||
| 		cl := server.NewClient(nil, "local", "inline", true) | ||||
| 		for range time.Tick(time.Second * 1) { | ||||
| 			err := server.InjectPacket(cl, packets.Packet{ | ||||
| 				FixedHeader: packets.FixedHeader{ | ||||
| 					Type: packets.Publish, | ||||
| 				}, | ||||
| 				TopicName: "direct/publish", | ||||
| 				Payload:   []byte("injected scheduled message"), | ||||
| 			}) | ||||
| 			if err != nil { | ||||
| 				server.Log.Error().Err(err).Msg("server.InjectPacket") | ||||
| 			} | ||||
| 			server.Log.Info().Msgf("main.go injected packet to direct/publish") | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// There is also a shorthand convenience function, Publish, for easily sending | ||||
| 	// publish packets if you are not concerned with creating your own packets. | ||||
| 	go func() { | ||||
| 		for range time.Tick(time.Second * 5) { | ||||
| 			err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0) | ||||
| 			if err != nil { | ||||
| 				server.Log.Error().Err(err).Msg("server.Publish") | ||||
| 			} | ||||
| 			server.Log.Info().Msgf("main.go issued direct message to direct/publish") | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
|  | ||||
| type ExampleHook struct { | ||||
| 	mqtt.HookBase | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) ID() string { | ||||
| 	return "events-example" | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{ | ||||
| 		mqtt.OnConnect, | ||||
| 		mqtt.OnDisconnect, | ||||
| 		mqtt.OnSubscribed, | ||||
| 		mqtt.OnUnsubscribed, | ||||
| 		mqtt.OnPublished, | ||||
| 		mqtt.OnPublish, | ||||
| 	}, []byte{b}) | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) Init(config any) error { | ||||
| 	h.Log.Info().Msg("initialised") | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.Log.Info().Str("client", cl.ID).Msgf("client connected") | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) { | ||||
| 	h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected") | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { | ||||
| 	h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes) | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed") | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { | ||||
| 	h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client") | ||||
|  | ||||
| 	pkx := pk | ||||
| 	if string(pk.Payload) == "hello" { | ||||
| 		pkx.Payload = []byte("hello world") | ||||
| 		h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client") | ||||
| 	} | ||||
|  | ||||
| 	return pkx, nil | ||||
| } | ||||
|  | ||||
| func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client") | ||||
| } | ||||
| @@ -1,16 +1,19 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"bytes" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/logrusorgru/aurora" | ||||
|  | ||||
| 	mqtt "github.com/mochi-co/mqtt/server" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners" | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| @@ -22,41 +25,51 @@ func main() { | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("PAHO Testing Suite")) | ||||
| 	server := mqtt.New(nil) | ||||
| 	server.Options.Capabilities.ServerKeepAlive = 60 | ||||
| 	server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true | ||||
| 	server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true | ||||
| 	server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true | ||||
|  | ||||
| 	server := mqtt.New() | ||||
| 	tcp := listeners.NewTCP("t1", ":1883") | ||||
| 	err := server.AddListener(tcp, &listeners.Config{ | ||||
| 		Auth: new(Auth), | ||||
| 	}) | ||||
| 	_ = server.AddHook(new(pahoAuthHook), nil) | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err := server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go server.Serve() | ||||
| 	fmt.Println(aurora.BgMagenta("  Started!  ")) | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	fmt.Println(aurora.BgRed("  Caught Signal  ")) | ||||
|  | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	fmt.Println(aurora.BgGreen("  Finished  ")) | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
|  | ||||
| // Auth is an example auth provider for the server. | ||||
| type Auth struct{} | ||||
| type pahoAuthHook struct { | ||||
| 	mqtt.HookBase | ||||
| } | ||||
|  | ||||
| // Auth returns true if a username and password are acceptable. | ||||
| // Auth always returns true. | ||||
| func (a *Auth) Authenticate(user, password []byte) bool { | ||||
| func (h *pahoAuthHook) ID() string { | ||||
| 	return "allow-all-auth" | ||||
| } | ||||
|  | ||||
| func (h *pahoAuthHook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{ | ||||
| 		mqtt.OnConnectAuthenticate, | ||||
| 		mqtt.OnACLCheck, | ||||
| 	}, []byte{b}) | ||||
| } | ||||
|  | ||||
| func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // ACL returns true if a user has access permissions to read or write on a topic. | ||||
| // ACL is used to deny access to a specific topic to satisfy Test.test_subscribe_failure. | ||||
| func (a *Auth) ACL(user []byte, topic string, write bool) bool { | ||||
| 	if topic == "test/nosubscribe" { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { | ||||
| 	return topic != "test/nosubscribe" | ||||
| } | ||||
|   | ||||
							
								
								
									
										59
									
								
								examples/persistence/badger/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								examples/persistence/badger/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage/badger" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	badgerPath := ".badger" | ||||
| 	defer os.RemoveAll(badgerPath) // remove the example badger files at the end | ||||
|  | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
| 	err := server.AddHook(new(badger.Hook), &badger.Options{ | ||||
| 		Path: badgerPath, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err = server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
|  | ||||
| } | ||||
							
								
								
									
										57
									
								
								examples/persistence/bolt/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								examples/persistence/bolt/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage/bolt" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| 	"go.etcd.io/bbolt" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
| 	err := server.AddHook(new(bolt.Hook), bolt.Options{ | ||||
| 		Path: "bolt.db", | ||||
| 		Options: &bbolt.Options{ | ||||
| 			Timeout: 500 * time.Millisecond, | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err = server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
| @@ -1,60 +0,0 @@ | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/logrusorgru/aurora" | ||||
| 	"go.etcd.io/bbolt" | ||||
|  | ||||
| 	mqtt "github.com/mochi-co/mqtt/server" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| 	"github.com/mochi-co/mqtt/server/persistence/bolt" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("Persistence")) | ||||
|  | ||||
| 	server := mqtt.New() | ||||
| 	tcp := listeners.NewTCP("t1", ":1883") | ||||
| 	err := server.AddListener(tcp, &listeners.Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	err = server.AddStore(bolt.New("mochi-test.db", &bbolt.Options{ | ||||
| 		Timeout: 500 * time.Millisecond, | ||||
| 	})) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
| 	fmt.Println(aurora.BgMagenta("  Started!  ")) | ||||
|  | ||||
| 	<-done | ||||
| 	fmt.Println(aurora.BgRed("  Caught Signal  ")) | ||||
|  | ||||
| 	server.Close() | ||||
| 	fmt.Println(aurora.BgGreen("  Finished  ")) | ||||
| } | ||||
							
								
								
									
										65
									
								
								examples/persistence/redis/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								examples/persistence/redis/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,65 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage/redis" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| 	"github.com/rs/zerolog" | ||||
|  | ||||
| 	rv8 "github.com/go-redis/redis/v8" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	sigs := make(chan os.Signal, 1) | ||||
| 	done := make(chan bool, 1) | ||||
| 	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | ||||
| 	go func() { | ||||
| 		<-sigs | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
| 	l := server.Log.Level(zerolog.DebugLevel) | ||||
| 	server.Log = &l | ||||
|  | ||||
| 	err := server.AddHook(new(redis.Hook), &redis.Options{ | ||||
| 		Options: &rv8.Options{ | ||||
| 			Addr:     "localhost:6379", // default redis address | ||||
| 			Password: "",               // your password | ||||
| 			DB:       0,                // your redis db | ||||
| 		}, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err = server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
|  | ||||
| } | ||||
| @@ -1,17 +1,18 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/logrusorgru/aurora" | ||||
|  | ||||
| 	mqtt "github.com/mochi-co/mqtt/server" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| @@ -23,13 +24,22 @@ func main() { | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP")) | ||||
| 	// An example of configuring various server options... | ||||
| 	options := &mqtt.Options{ | ||||
| 		// InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages | ||||
| 	} | ||||
|  | ||||
| 	server := mqtt.New() | ||||
| 	tcp := listeners.NewTCP("t1", ":1883") | ||||
| 	err := server.AddListener(tcp, &listeners.Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 	}) | ||||
| 	server := mqtt.New(options) | ||||
|  | ||||
| 	// For security reasons, the default implementation disallows all connections. | ||||
| 	// If you want to allow all connections, you must specifically allow it. | ||||
| 	err := server.AddHook(new(auth.AllowHook), nil) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", nil) | ||||
| 	err = server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| @@ -40,12 +50,9 @@ func main() { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
| 	fmt.Println(aurora.BgMagenta("  Started!  ")) | ||||
|  | ||||
| 	<-done | ||||
| 	fmt.Println(aurora.BgRed("  Caught Signal  ")) | ||||
|  | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	fmt.Println(aurora.BgGreen("  Finished  ")) | ||||
|  | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
|   | ||||
| @@ -1,17 +1,19 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"crypto/tls" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/logrusorgru/aurora" | ||||
|  | ||||
| 	mqtt "github.com/mochi-co/mqtt/server" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| @@ -55,51 +57,61 @@ func main() { | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL")) | ||||
|  | ||||
| 	server := mqtt.New() | ||||
| 	tcp := listeners.NewTCP("t1", ":1883") | ||||
| 	err := server.AddListener(tcp, &listeners.Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 		TLS: &listeners.TLS{ | ||||
| 			Certificate: testCertificate, | ||||
| 			PrivateKey:  testPrivateKey, | ||||
| 		}, | ||||
| 	}) | ||||
| 	cert, err := tls.X509KeyPair(testCertificate, testPrivateKey) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	ws := listeners.NewWebsocket("ws1", ":1882") | ||||
| 	err = server.AddListener(ws, &listeners.Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 		TLS: &listeners.TLS{ | ||||
| 			Certificate: testCertificate, | ||||
| 			PrivateKey:  testPrivateKey, | ||||
| 		}, | ||||
| 	// Basic TLS Config | ||||
| 	tlsConfig := &tls.Config{ | ||||
| 		Certificates: []tls.Certificate{cert}, | ||||
| 	} | ||||
|  | ||||
| 	// Optionally, if you want clients to authenticate only with certs issued by your CA, | ||||
| 	// you might want to use something like this: | ||||
| 	// certPool := x509.NewCertPool() | ||||
| 	// _ = certPool.AppendCertsFromPEM(caCertPem) | ||||
| 	// tlsConfig := &tls.Config{ | ||||
| 	//	ClientCAs: certPool, | ||||
| 	// 	ClientAuth: tls.RequireAndVerifyClientCert, | ||||
| 	// } | ||||
|  | ||||
| 	server := mqtt.New(nil) | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
| 	tcp := listeners.NewTCP("t1", ":1883", &listeners.Config{ | ||||
| 		TLSConfig: tlsConfig, | ||||
| 	}) | ||||
| 	err = server.AddListener(tcp) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	stats := listeners.NewHTTPStats("stats", ":8080") | ||||
| 	err = server.AddListener(stats, &listeners.Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 		TLS: &listeners.TLS{ | ||||
| 			Certificate: testCertificate, | ||||
| 			PrivateKey:  testPrivateKey, | ||||
| 		}, | ||||
| 	ws := listeners.NewWebsocket("ws1", ":1882", &listeners.Config{ | ||||
| 		TLSConfig: tlsConfig, | ||||
| 	}) | ||||
| 	err = server.AddListener(ws) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go server.Serve() | ||||
| 	fmt.Println(aurora.BgMagenta("  Started!  ")) | ||||
| 	stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{ | ||||
| 		TLSConfig: tlsConfig, | ||||
| 	}, nil) | ||||
| 	err = server.AddListener(stats) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		err := server.Serve() | ||||
| 		if err != nil { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	<-done | ||||
| 	fmt.Println(aurora.BgRed("  Caught Signal  ")) | ||||
|  | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	fmt.Println(aurora.BgGreen("  Finished  ")) | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
|   | ||||
| @@ -1,16 +1,18 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"syscall" | ||||
|  | ||||
| 	"github.com/logrusorgru/aurora" | ||||
|  | ||||
| 	mqtt "github.com/mochi-co/mqtt/server" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners" | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/auth" | ||||
| 	"github.com/mochi-co/mqtt/v2/listeners" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| @@ -22,11 +24,11 @@ func main() { | ||||
| 		done <- true | ||||
| 	}() | ||||
|  | ||||
| 	fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP")) | ||||
| 	server := mqtt.New(nil) | ||||
| 	_ = server.AddHook(new(auth.AllowHook), nil) | ||||
|  | ||||
| 	server := mqtt.New() | ||||
| 	ws := listeners.NewWebsocket("ws1", ":1882") | ||||
| 	err := server.AddListener(ws, nil) | ||||
| 	ws := listeners.NewWebsocket("ws1", ":1882", nil) | ||||
| 	err := server.AddListener(ws) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
| @@ -37,11 +39,9 @@ func main() { | ||||
| 			log.Fatal(err) | ||||
| 		} | ||||
| 	}() | ||||
| 	fmt.Println(aurora.BgMagenta("  Started!  ")) | ||||
|  | ||||
| 	<-done | ||||
| 	fmt.Println(aurora.BgRed("  Caught Signal  ")) | ||||
|  | ||||
| 	server.Log.Warn().Msg("caught signal, stopping...") | ||||
| 	server.Close() | ||||
| 	fmt.Println(aurora.BgGreen("  Finished  ")) | ||||
| 	server.Log.Info().Msg("main.go finished") | ||||
| } | ||||
|   | ||||
							
								
								
									
										101
									
								
								fanpool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								fanpool.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co, chowyu08, muXxer | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
|  | ||||
| 	xh "github.com/cespare/xxhash/v2" | ||||
| ) | ||||
|  | ||||
| // taskChan is a channel for incoming task functions. | ||||
| type taskChan chan func() | ||||
|  | ||||
| // FanPool is a fixed-sized fan-style worker pool with multiple | ||||
| // working 'columns'. Instead of a single queue channel processed by | ||||
| // many goroutines, this fan pool uses many queue channels each | ||||
| // processed by a single goroutine. | ||||
| // Very special thanks are given to the authors of HMQ in particular | ||||
| // @chowyu08 and @muXxer for their work on the fixpool worker pool | ||||
| // https://github.com/fhmq/hmq/blob/master/pool/fixpool.go | ||||
| // from which this fan-pool is heavily inspired. | ||||
| type FanPool struct { | ||||
| 	queue    []taskChan | ||||
| 	wg       sync.WaitGroup | ||||
| 	capacity uint64 | ||||
| 	perChan  uint64 | ||||
| 	Mutex    sync.Mutex | ||||
| } | ||||
|  | ||||
| // New returns a new instance of FanPool. fanSize controls the number of 'columns' | ||||
| // of the fan, whereas queueSize controls the size of each column's queue. | ||||
| func NewFanPool(fanSize, queueSize uint64) *FanPool { | ||||
| 	pool := &FanPool{ | ||||
| 		capacity: fanSize, | ||||
| 		perChan:  queueSize, | ||||
| 		queue:    make([]taskChan, fanSize), | ||||
| 	} | ||||
|  | ||||
| 	pool.fillWorkers(fanSize) | ||||
|  | ||||
| 	return pool | ||||
| } | ||||
|  | ||||
| // fillWorkers adds columns to the fan pool with an associated worker goroutine. | ||||
| func (p *FanPool) fillWorkers(n uint64) { | ||||
| 	for i := uint64(0); i < n; i++ { | ||||
| 		p.queue[i] = make(taskChan, p.perChan) | ||||
| 		go p.worker(p.queue[i]) | ||||
| 		p.wg.Add(1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // worker is a worker goroutine which processes tasks from a single queue. | ||||
| func (p *FanPool) worker(ch taskChan) { | ||||
| 	defer p.wg.Done() | ||||
| 	var task func() | ||||
| 	var ok bool | ||||
| 	for { | ||||
| 		task, ok = <-ch | ||||
| 		if !ok { | ||||
| 			return | ||||
| 		} | ||||
| 		task() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Enqueue adds a new task to the queue to be processed. | ||||
| func (p *FanPool) Enqueue(id string, task func()) { | ||||
| 	if p.Size() == 0 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// We can use xh.Sum64 to get a specific queue index | ||||
| 	// which remains the same for a client id, giving each | ||||
| 	// client their own queue. | ||||
| 	p.queue[xh.Sum64([]byte(id))%p.Size()] <- task | ||||
| } | ||||
|  | ||||
| // Wait blocks until all the workers in the pool have completed. | ||||
| func (p *FanPool) Wait() { | ||||
| 	p.wg.Wait() | ||||
| } | ||||
|  | ||||
| // Close issues a shutdown signal to the workers. | ||||
| func (p *FanPool) Close() { | ||||
| 	for i := 0; i < int(p.Size()); i++ { | ||||
| 		if p.queue[i] != nil { | ||||
| 			close(p.queue[i]) | ||||
| 		} | ||||
| 	} | ||||
| 	p.queue = nil | ||||
| 	atomic.StoreUint64(&p.capacity, 0) | ||||
| } | ||||
|  | ||||
| // Size returns the current number of workers in the pool. | ||||
| func (p *FanPool) Size() uint64 { | ||||
| 	return atomic.LoadUint64(&p.capacity) | ||||
| } | ||||
							
								
								
									
										89
									
								
								fanpool_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								fanpool_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,89 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestFanPool(t *testing.T) { | ||||
| 	f := NewFanPool(1, 2) | ||||
| 	require.NotNil(t, f) | ||||
| 	require.Equal(t, uint64(1), f.capacity) | ||||
| 	require.Equal(t, 2, cap(f.queue[0])) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func() { | ||||
| 		f.Enqueue("test", func() { | ||||
| 			o <- true | ||||
| 		}) | ||||
| 	}() | ||||
|  | ||||
| 	require.True(t, <-o) | ||||
| 	f.Close() | ||||
| 	f.Wait() | ||||
| } | ||||
|  | ||||
| func TestFillWorkers(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		perChan: 3, | ||||
| 		queue:   make([]taskChan, 2), | ||||
| 	} | ||||
| 	f.fillWorkers(2) | ||||
| 	require.Len(t, f.queue, 2) | ||||
| 	require.Equal(t, 3, cap(f.queue[0])) | ||||
| } | ||||
|  | ||||
| func TestEnqueue(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		capacity: 2, | ||||
| 		queue: []taskChan{ | ||||
| 			make(taskChan, 2), | ||||
| 			make(taskChan, 2), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		f.Enqueue("a", func() {}) | ||||
| 	}() | ||||
| 	require.NotNil(t, <-f.queue[1]) | ||||
| } | ||||
|  | ||||
| func TestEnqueueOnEmpty(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		queue: []taskChan{}, | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		f.Enqueue("a", func() {}) | ||||
| 	}() | ||||
|  | ||||
| 	require.Len(t, f.queue, 0) | ||||
| } | ||||
|  | ||||
| func TestSize(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		capacity: 10, | ||||
| 	} | ||||
|  | ||||
| 	require.Equal(t, uint64(10), f.Size()) | ||||
| } | ||||
|  | ||||
| func TestClose(t *testing.T) { | ||||
| 	f := &FanPool{ | ||||
| 		capacity: 3, | ||||
| 		queue: []taskChan{ | ||||
| 			make(taskChan, 2), | ||||
| 			make(taskChan, 2), | ||||
| 			make(taskChan, 2), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	f.Close() | ||||
| 	require.Equal(t, uint64(0), f.Size()) | ||||
| 	require.Nil(t, f.queue) | ||||
| } | ||||
							
								
								
									
										39
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										39
									
								
								go.mod
									
									
									
									
									
								
							| @@ -1,21 +1,40 @@ | ||||
| module github.com/mochi-co/mqtt | ||||
| module github.com/mochi-co/mqtt/v2 | ||||
|  | ||||
| go 1.17 | ||||
| go 1.19 | ||||
|  | ||||
| require ( | ||||
| 	github.com/alicebob/miniredis/v2 v2.23.0 | ||||
| 	github.com/asdine/storm v2.1.2+incompatible | ||||
| 	github.com/asdine/storm/v3 v3.2.1 | ||||
| 	github.com/gorilla/websocket v1.4.2 | ||||
| 	github.com/jinzhu/copier v0.3.4 | ||||
| 	github.com/logrusorgru/aurora v2.0.3+incompatible | ||||
| 	github.com/rs/xid v1.3.0 | ||||
| 	github.com/stretchr/testify v1.7.0 | ||||
| 	go.etcd.io/bbolt v1.3.6 | ||||
| 	github.com/cespare/xxhash/v2 v2.1.2 | ||||
| 	github.com/go-redis/redis/v8 v8.11.5 | ||||
| 	github.com/gorilla/websocket v1.5.0 | ||||
| 	github.com/jinzhu/copier v0.3.5 | ||||
| 	github.com/rs/xid v1.4.0 | ||||
| 	github.com/rs/zerolog v1.28.0 | ||||
| 	github.com/stretchr/testify v1.7.1 | ||||
| 	github.com/timshannon/badgerhold v1.0.0 | ||||
| 	go.etcd.io/bbolt v1.3.5 | ||||
| 	gopkg.in/yaml.v3 v3.0.1 | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect | ||||
| 	github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/dgraph-io/badger v1.6.0 // indirect | ||||
| 	github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect | ||||
| 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | ||||
| 	github.com/dustin/go-humanize v1.0.0 // indirect | ||||
| 	github.com/golang/protobuf v1.5.0 // indirect | ||||
| 	github.com/golang/snappy v0.0.3 // indirect | ||||
| 	github.com/mattn/go-colorable v0.1.12 // indirect | ||||
| 	github.com/mattn/go-isatty v0.0.14 // indirect | ||||
| 	github.com/pkg/errors v0.9.1 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect | ||||
| 	github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect | ||||
| 	golang.org/x/net v0.0.0-20220927171203-f486391704dc // indirect | ||||
| 	golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect | ||||
| 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect | ||||
| 	google.golang.org/protobuf v1.28.1 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										121
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										121
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,58 +1,143 @@ | ||||
| github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= | ||||
| github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M= | ||||
| github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= | ||||
| github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= | ||||
| github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM= | ||||
| github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= | ||||
| github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM= | ||||
| github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM= | ||||
| github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= | ||||
| github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= | ||||
| github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw= | ||||
| github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88= | ||||
| github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= | ||||
| github.com/asdine/storm v2.1.2+incompatible h1:dczuIkyqwY2LrtXPz8ixMrU/OFgZp71kbKTHGrXYt/Q= | ||||
| github.com/asdine/storm v2.1.2+incompatible/go.mod h1:RarYDc9hq1UPLImuiXK3BIWPJLdIygvV3PsInK0FbVQ= | ||||
| github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac= | ||||
| github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0= | ||||
| github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= | ||||
| github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | ||||
| github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= | ||||
| github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= | ||||
| github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= | ||||
| github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= | ||||
| github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= | ||||
| github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= | ||||
| github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= | ||||
| github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= | ||||
| github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||
| github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | ||||
| github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||
| github.com/dgraph-io/badger v1.6.0 h1:DshxFxZWXUcO0xX476VJC07Xsr6ZCBVRHKZ93Oh7Evo= | ||||
| github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= | ||||
| github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= | ||||
| github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= | ||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= | ||||
| github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | ||||
| github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= | ||||
| github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= | ||||
| github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= | ||||
| github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | ||||
| github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= | ||||
| github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= | ||||
| github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= | ||||
| github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | ||||
| github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= | ||||
| github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | ||||
| github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= | ||||
| github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= | ||||
| github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | ||||
| github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= | ||||
| github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= | ||||
| github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||
| github.com/jinzhu/copier v0.3.4 h1:mfU6jI9PtCeUjkjQ322dlff9ELjGDu975C2p/nrubVI= | ||||
| github.com/jinzhu/copier v0.3.4/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= | ||||
| github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= | ||||
| github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= | ||||
| github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= | ||||
| github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= | ||||
| github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||
| github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= | ||||
| github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= | ||||
| github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg= | ||||
| github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= | ||||
| github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= | ||||
| github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= | ||||
| github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= | ||||
| github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= | ||||
| github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= | ||||
| github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczGlG91VSDkswnjF5A8= | ||||
| github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= | ||||
| github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= | ||||
| github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= | ||||
| github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= | ||||
| github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= | ||||
| github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= | ||||
| github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= | ||||
| github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= | ||||
| github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | ||||
| github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= | ||||
| github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= | ||||
| github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= | ||||
| github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | ||||
| github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= | ||||
| github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | ||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||
| github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= | ||||
| github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= | ||||
| github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY= | ||||
| github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= | ||||
| github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY= | ||||
| github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= | ||||
| github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= | ||||
| github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= | ||||
| github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= | ||||
| github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= | ||||
| github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= | ||||
| github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= | ||||
| github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= | ||||
| github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||||
| github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= | ||||
| github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= | ||||
| github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | ||||
| github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= | ||||
| github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= | ||||
| github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | ||||
| github.com/timshannon/badgerhold v1.0.0 h1:LtqnDRVP7294FWRiZCIfQa6Tt0bGmlzbO8c364QC2Y8= | ||||
| github.com/timshannon/badgerhold v1.0.0/go.mod h1:Vv2Jj0PAfzqViEpGvJzLP8PY07x1iXLgKRuLY7bqPOE= | ||||
| github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= | ||||
| github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI= | ||||
| github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= | ||||
| github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= | ||||
| github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw= | ||||
| github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA= | ||||
| go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= | ||||
| go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= | ||||
| go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= | ||||
| go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= | ||||
| go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= | ||||
| golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= | ||||
| golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||
| golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= | ||||
| golang.org/x/net v0.0.0-20191105084925-a882066a44e0 h1:QPlSTtPE2k6PZPasQUbzuK3p9JbS+vMXYVto8g/yrsg= | ||||
| golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||||
| golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||||
| golang.org/x/net v0.0.0-20220927171203-f486391704dc h1:FxpXZdoBqT8RjqTy6i1E8nXHhW21wK7ptQ/EPIGxzPQ= | ||||
| golang.org/x/net v0.0.0-20220927171203-f486391704dc/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= | ||||
| golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d h1:L/IKR6COd7ubZrs2oTnTi73IhgqJ71c9s80WsQnh0Es= | ||||
| golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= | ||||
| golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
| golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= | ||||
| golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= | ||||
| golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= | ||||
| google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= | ||||
| google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= | ||||
| google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= | ||||
| google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= | ||||
| gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= | ||||
| gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= | ||||
| gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | ||||
| gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= | ||||
| gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
| gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||||
| gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
|   | ||||
							
								
								
									
										784
									
								
								hooks.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										784
									
								
								hooks.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,784 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	SetOptions byte = iota | ||||
| 	OnSysInfoTick | ||||
| 	OnStarted | ||||
| 	OnStopped | ||||
| 	OnConnectAuthenticate | ||||
| 	OnACLCheck | ||||
| 	OnConnect | ||||
| 	OnSessionEstablished | ||||
| 	OnDisconnect | ||||
| 	OnAuthPacket | ||||
| 	OnPacketRead | ||||
| 	OnPacketEncode | ||||
| 	OnPacketSent | ||||
| 	OnPacketProcessed | ||||
| 	OnSubscribe | ||||
| 	OnSubscribed | ||||
| 	OnSelectSubscribers | ||||
| 	OnUnsubscribe | ||||
| 	OnUnsubscribed | ||||
| 	OnPublish | ||||
| 	OnPublished | ||||
| 	OnRetainMessage | ||||
| 	OnQosPublish | ||||
| 	OnQosComplete | ||||
| 	OnQosDropped | ||||
| 	OnWill | ||||
| 	OnWillSent | ||||
| 	OnClientExpired | ||||
| 	OnRetainedExpired | ||||
| 	OnExpireInflights | ||||
| 	StoredClients | ||||
| 	StoredSubscriptions | ||||
| 	StoredInflightMessages | ||||
| 	StoredRetainedMessages | ||||
| 	StoredSysInfo | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// ErrInvalidConfigType indicates a different Type of config value was expected to what was received. | ||||
| 	ErrInvalidConfigType = errors.New("invalid config type provided") | ||||
| ) | ||||
|  | ||||
| // Hook provides an interface of handlers for different events which occur | ||||
| // during the lifecycle of the broker. | ||||
| type Hook interface { | ||||
| 	ID() string | ||||
| 	Provides(b byte) bool | ||||
| 	Init(config any) error | ||||
| 	Stop() error | ||||
| 	SetOpts(l *zerolog.Logger, o *HookOptions) | ||||
| 	OnStarted() | ||||
| 	OnStopped() | ||||
| 	OnConnectAuthenticate(cl *Client, pk packets.Packet) bool | ||||
| 	OnACLCheck(cl *Client, topic string, write bool) bool | ||||
| 	OnSysInfoTick(*system.Info) | ||||
| 	OnConnect(cl *Client, pk packets.Packet) | ||||
| 	OnSessionEstablished(cl *Client, pk packets.Packet) | ||||
| 	OnDisconnect(cl *Client, err error, expire bool) | ||||
| 	OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) | ||||
| 	OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation | ||||
| 	OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet        // modify a packet before it is byte-encoded and written to the client | ||||
| 	OnPacketSent(cl *Client, pk packets.Packet, b []byte)               // triggers when packet bytes have been written to the client | ||||
| 	OnPacketProcessed(cl *Client, pk packets.Packet, err error)         // triggers after a packet from the client been processed (handled) | ||||
| 	OnSubscribe(cl *Client, pk packets.Packet) packets.Packet | ||||
| 	OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) | ||||
| 	OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers | ||||
| 	OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet | ||||
| 	OnUnsubscribed(cl *Client, pk packets.Packet) | ||||
| 	OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) | ||||
| 	OnPublished(cl *Client, pk packets.Packet) | ||||
| 	OnRetainMessage(cl *Client, pk packets.Packet, r int64) | ||||
| 	OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) | ||||
| 	OnQosComplete(cl *Client, pk packets.Packet) | ||||
| 	OnQosDropped(cl *Client, pk packets.Packet) | ||||
| 	OnWill(cl *Client, will Will) (Will, error) | ||||
| 	OnWillSent(cl *Client, pk packets.Packet) | ||||
| 	OnClientExpired(cl *Client) | ||||
| 	OnRetainedExpired(filter string) | ||||
| 	OnExpireInflights(cl *Client, expiry int64) | ||||
| 	StoredClients() ([]storage.Client, error) | ||||
| 	StoredSubscriptions() ([]storage.Subscription, error) | ||||
| 	StoredInflightMessages() ([]storage.Message, error) | ||||
| 	StoredRetainedMessages() ([]storage.Message, error) | ||||
| 	StoredSysInfo() (storage.SystemInfo, error) | ||||
| } | ||||
|  | ||||
| // HookOptions contains values which are inherited from the server on initialisation. | ||||
| type HookOptions struct { | ||||
| 	Capabilities *Capabilities | ||||
| } | ||||
|  | ||||
| // Hooks is a slice of Hook interfaces to be called in sequence. | ||||
| type Hooks struct { | ||||
| 	Log        *zerolog.Logger // a logger for the hook (from the server) | ||||
| 	internal   []Hook          // a slice of hooks | ||||
| 	wg         sync.WaitGroup  // a waitgroup for syncing hook shutdown | ||||
| 	qty        int64           // the number of hooks in use | ||||
| 	sync.Mutex                 // a mutex | ||||
| } | ||||
|  | ||||
| // Len returns the number of hooks added. | ||||
| func (h *Hooks) Len() int64 { | ||||
| 	return atomic.LoadInt64(&h.qty) | ||||
| } | ||||
|  | ||||
| // Provides returns true if any one hook provides any of the requested hook methods. | ||||
| func (h *Hooks) Provides(b ...byte) bool { | ||||
| 	for _, hook := range h.internal { | ||||
| 		for _, hb := range b { | ||||
| 			if hook.Provides(hb) { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // Add adds and initializes a new hook. | ||||
| func (h *Hooks) Add(hook Hook, config any) error { | ||||
| 	h.Lock() | ||||
| 	defer h.Unlock() | ||||
| 	if h.internal == nil { | ||||
| 		h.internal = []Hook{} | ||||
| 	} | ||||
|  | ||||
| 	err := hook.Init(config) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err) | ||||
| 	} | ||||
|  | ||||
| 	h.internal = append(h.internal, hook) | ||||
| 	atomic.AddInt64(&h.qty, 1) | ||||
| 	h.wg.Add(1) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Stop indicates all attached hooks to gracefully end. | ||||
| func (h *Hooks) Stop() { | ||||
| 	go func() { | ||||
| 		for _, hook := range h.internal { | ||||
| 			h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook") | ||||
| 			if err := hook.Stop(); err != nil { | ||||
| 				h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook") | ||||
| 			} | ||||
|  | ||||
| 			h.wg.Done() | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	h.wg.Wait() | ||||
| } | ||||
|  | ||||
| // OnSysInfoTick is called when the $SYS topic values are published out. | ||||
| func (h *Hooks) OnSysInfoTick(sys *system.Info) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnSysInfoTick) { | ||||
| 			hook.OnSysInfoTick(sys) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnStarted is called when the server has successfully started. | ||||
| func (h *Hooks) OnStarted() { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnStarted) { | ||||
| 			hook.OnStarted() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnStopped is called when the server has successfully stopped. | ||||
| func (h *Hooks) OnStopped() { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnStopped) { | ||||
| 			hook.OnStopped() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnConnect is called when a new client connects. | ||||
| func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnConnect) { | ||||
| 			hook.OnConnect(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnSessionEstablished is called when a new client establishes a session (after OnConnect). | ||||
| func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnSessionEstablished) { | ||||
| 			hook.OnSessionEstablished(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnDisconnect is called when a client is disconnected for any reason. | ||||
| func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnDisconnect) { | ||||
| 			hook.OnDisconnect(cl, err, expire) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnPacketRead is called when a packet is received from a client. | ||||
| func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||
| 	pkx = pk | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnPacketRead) { | ||||
| 			npk, err := hook.OnPacketRead(cl, pkx) | ||||
| 			if err != nil && errors.Is(err, packets.ErrRejectPacket) { | ||||
| 				h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected") | ||||
| 				return pk, err | ||||
| 			} else if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			pkx = npk | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // OnAuthPacket is called when an auth packet is received. It is intended to allow developers | ||||
| // to create their own auth packet handling mechanisms. | ||||
| func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||
| 	pkx = pk | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnAuthPacket) { | ||||
| 			npk, err := hook.OnAuthPacket(cl, pkx) | ||||
| 			if err != nil { | ||||
| 				return pk, err | ||||
| 			} | ||||
|  | ||||
| 			pkx = npk | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // OnPacketEncode is called immediately before a packet is encoded to be sent to a client. | ||||
| func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnPacketEncode) { | ||||
| 			pk = hook.OnPacketEncode(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return pk | ||||
| } | ||||
|  | ||||
| // OnPacketProcessed is called when a packet has been received and successfully handled by the broker. | ||||
| func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnPacketProcessed) { | ||||
| 			hook.OnPacketProcessed(cl, pk, err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter | ||||
| // containing the bytes sent. | ||||
| func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnPacketSent) { | ||||
| 			hook.OnPacketSent(cl, pk, b) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnSubscribe is called when a client subscribes to one or more filters. This method | ||||
| // differs from OnSubscribed in that it allows you to modify the subscription values | ||||
| // before the packet is processed. The return values of the hook methods are passed-through | ||||
| // in the order the hooks were attached. | ||||
| func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnSubscribe) { | ||||
| 			pk = hook.OnSubscribe(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| 	return pk | ||||
| } | ||||
|  | ||||
| // OnSubscribed is called when a client subscribes to one or more filters. | ||||
| func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnSubscribed) { | ||||
| 			hook.OnSubscribed(cl, pk, reasonCodes) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnSelectSubscribers is called when subscribers have been collected for a topic, but before | ||||
| // shared subscription subscribers have been selected. This hook can be used to programmatically | ||||
| // remove or add clients to a publish to subscribers process, or to select the subscriber for a shared | ||||
| // group in a custom manner (such as based on client id, ip, etc). | ||||
| func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnSelectSubscribers) { | ||||
| 			subs = hook.OnSelectSubscribers(subs, pk) | ||||
| 		} | ||||
| 	} | ||||
| 	return subs | ||||
| } | ||||
|  | ||||
| // OnUnsubscribe is called when a client unsubscribes from one or more filters. This method | ||||
| // differs from OnUnsubscribed in that it allows you to modify the unsubscription values | ||||
| // before the packet is processed. The return values of the hook methods are passed-through | ||||
| // in the order the hooks were attached. | ||||
| func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnUnsubscribe) { | ||||
| 			pk = hook.OnUnsubscribe(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| 	return pk | ||||
| } | ||||
|  | ||||
| // OnUnsubscribed is called when a client unsubscribes from one or more filters. | ||||
| func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnUnsubscribed) { | ||||
| 			hook.OnUnsubscribed(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnPublish is called when a client publishes a message. This method differs from OnPublished | ||||
| // in that it allows you to modify you to modify the incoming packet before it is processed. | ||||
| // The return values of the hook methods are passed-through in the order the hooks were attached. | ||||
| func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) { | ||||
| 	pkx = pk | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnPublish) { | ||||
| 			npk, err := hook.OnPublish(cl, pkx) | ||||
| 			if err != nil && errors.Is(err, packets.ErrRejectPacket) { | ||||
| 				h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected") | ||||
| 				return pk, err | ||||
| 			} else if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			pkx = npk | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // OnPublished is called when a client has published a message to subscribers. | ||||
| func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnPublished) { | ||||
| 			hook.OnPublished(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainMessage is called then a published message is retained. | ||||
| func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnRetainMessage) { | ||||
| 			hook.OnRetainMessage(cl, pk, r) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber. | ||||
| // In other words, this method is called when a new inflight message is created or resent. | ||||
| // It is typically used to store a new inflight message. | ||||
| func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnQosPublish) { | ||||
| 			hook.OnQosPublish(cl, pk, sent, resends) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosComplete is called when the Qos flow for a message has been completed. | ||||
| // In other words, when an inflight message is resolved. | ||||
| // It is typically used to delete an inflight message from a store. | ||||
| func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnQosComplete) { | ||||
| 			hook.OnQosComplete(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosDropped is called the Qos flow for a message expires. In other words, when | ||||
| // an inflight message expires or is abandoned. | ||||
| // It is typically used to delete an inflight message from a store. | ||||
| func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnQosDropped) { | ||||
| 			hook.OnQosDropped(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnWill is called when a client disconnects and publishes an LWT message. This method | ||||
| // differs from OnWillSent in that it allows you to modify the LWT message before it is | ||||
| // published. The return values of the hook methods are passed-through in the order | ||||
| // the hooks were attached. | ||||
| func (h *Hooks) OnWill(cl *Client, will Will) Will { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnWill) { | ||||
| 			mlwt, err := hook.OnWill(cl, will) | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error") | ||||
| 				continue | ||||
| 			} | ||||
| 			will = mlwt | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return will | ||||
| } | ||||
|  | ||||
| // OnWillSent is called when an LWT message has been issued from a disconnecting client. | ||||
| func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnWillSent) { | ||||
| 			hook.OnWillSent(cl, pk) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnClientExpired is called when a client session has expired and should be deleted. | ||||
| func (h *Hooks) OnClientExpired(cl *Client) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnClientExpired) { | ||||
| 			hook.OnClientExpired(cl) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainedExpired is called when a retained message has expired and should be deleted. | ||||
| func (h *Hooks) OnRetainedExpired(filter string) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnRetainedExpired) { | ||||
| 			hook.OnRetainedExpired(filter) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // StoredClients returns all clients, e.g. from a persistent store, is used to | ||||
| // populate the server clients list before start. | ||||
| func (h *Hooks) StoredClients() (v []storage.Client, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(StoredClients) { | ||||
| 			v, err := hook.StoredClients() | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients") | ||||
| 				return v, err | ||||
| 			} | ||||
|  | ||||
| 			if len(v) > 0 { | ||||
| 				return v, nil | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is | ||||
| // used to populate the server subscriptions list before start. | ||||
| func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(StoredSubscriptions) { | ||||
| 			v, err := hook.StoredSubscriptions() | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions") | ||||
| 				return v, err | ||||
| 			} | ||||
|  | ||||
| 			if len(v) > 0 { | ||||
| 				return v, nil | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // StoredInflightMessages returns all inflight messages, e.g. from a persistent store, | ||||
| // and is used to populate the restored clients with inflight messages before start. | ||||
| func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(StoredInflightMessages) { | ||||
| 			v, err := hook.StoredInflightMessages() | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages") | ||||
| 				return v, err | ||||
| 			} | ||||
|  | ||||
| 			if len(v) > 0 { | ||||
| 				return v, nil | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // StoredRetainedMessages returns all retained messages, e.g. from a persistent store, | ||||
| // and is used to populate the server topics with retained messages before start. | ||||
| func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(StoredRetainedMessages) { | ||||
| 			v, err := hook.StoredRetainedMessages() | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages") | ||||
| 				return v, err | ||||
| 			} | ||||
|  | ||||
| 			if len(v) > 0 { | ||||
| 				return v, nil | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // StoredSysInfo returns a set of system info values. | ||||
| func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(StoredSysInfo) { | ||||
| 			v, err := hook.StoredSysInfo() | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info") | ||||
| 				return v, err | ||||
| 			} | ||||
|  | ||||
| 			if v.Version != "" { | ||||
| 				return v, nil | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // OnConnectAuthenticate is called when a user attempts to authenticate with the server. | ||||
| // An implementation of this method MUST be used to allow or deny access to the | ||||
| // server (see hooks/auth/allow_all or basic). It can be used in custom hooks to | ||||
| // check connecting users against an existing user database. | ||||
| func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnConnectAuthenticate) { | ||||
| 			if ok := hook.OnConnectAuthenticate(cl, pk); ok { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // OnACLCheck is called when a user attempts to publish or subscribe to a topic filter. | ||||
| // An implementation of this method MUST be used to allow or deny access to the | ||||
| // (see hooks/auth/allow_all or basic). It can be used in custom hooks to | ||||
| // check publishing and subscribing users against an existing permissions or roles database. | ||||
| func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnACLCheck) { | ||||
| 			if ok := hook.OnACLCheck(cl, topic, write); ok { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // OnExpireInflights is called when the server issues a clear request for expired | ||||
| // inflight messages. Expiry should be the time after which the message is no longer | ||||
| // valid (usually some time in the past). A message has expired if it's created time | ||||
| // is older than time.Now() minus Inflight TTL. This method can be used to expire | ||||
| // old inflight messages in a persistent store which doesnt support per-item TTL. | ||||
| func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) { | ||||
| 	for _, hook := range h.internal { | ||||
| 		if hook.Provides(OnExpireInflights) { | ||||
| 			hook.OnExpireInflights(cl, expiry) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // HookBase provides a set of default methods for each hook. It should be embedded in | ||||
| // all hooks. | ||||
| type HookBase struct { | ||||
| 	Hook | ||||
| 	Log  *zerolog.Logger | ||||
| 	Opts *HookOptions | ||||
| } | ||||
|  | ||||
| // ID returns the ID of the hook. | ||||
| func (h *HookBase) ID() string { | ||||
| 	return "base" | ||||
| } | ||||
|  | ||||
| // Provides indicates which methods a hook provides. The default is none - this method | ||||
| // should be overridden by the embedding hook. | ||||
| func (h *HookBase) Provides(b byte) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // Init performs any pre-start initializations for the hook, such as connecting to databases | ||||
| // or opening files. | ||||
| func (h *HookBase) Init(config any) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SetOpts is called by the server to propagate internal values and generally should | ||||
| // not be called manually. | ||||
| func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) { | ||||
| 	h.Log = l | ||||
| 	h.Opts = opts | ||||
| } | ||||
|  | ||||
| // Stop is called to gracefully shutdown the hook. | ||||
| func (h *HookBase) Stop() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // OnStarted is called when the server starts. | ||||
| func (h *HookBase) OnStarted() {} | ||||
|  | ||||
| // OnStopped is called when the server stops. | ||||
| func (h *HookBase) OnStopped() {} | ||||
|  | ||||
| // OnSysInfoTick is called when the server publishes system info. | ||||
| func (h *HookBase) OnSysInfoTick(*system.Info) {} | ||||
|  | ||||
| // OnConnectAuthenticate is called when a user attempts to authenticate with the server. | ||||
| func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // OnACLCheck is called when a user attempts to subscribe or publish to a topic. | ||||
| func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // OnConnect is called when a new client connects. | ||||
| func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnSessionEstablished is called when a new client establishes a session (after OnConnect). | ||||
| func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnDisconnect is called when a client is disconnected for any reason. | ||||
| func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {} | ||||
|  | ||||
| // OnAuthPacket is called when an auth packet is received from the client. | ||||
| func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) { | ||||
| 	return pk, nil | ||||
| } | ||||
|  | ||||
| // OnPacketRead is called when a packet is received. | ||||
| func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) { | ||||
| 	return pk, nil | ||||
| } | ||||
|  | ||||
| // OnPacketEncode is called before a packet is byte-encoded and written to the client. | ||||
| func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	return pk | ||||
| } | ||||
|  | ||||
| // OnPacketSent is called immediately after a packet is written to a client. | ||||
| func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {} | ||||
|  | ||||
| // OnPacketProcessed is called immediately after a packet from a client is processed. | ||||
| func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {} | ||||
|  | ||||
| // OnSubscribe is called when a client subscribes to one or more filters. | ||||
| func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	return pk | ||||
| } | ||||
|  | ||||
| // OnSubscribed is called when a client subscribes to one or more filters. | ||||
| func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {} | ||||
|  | ||||
| // OnSelectSubscribers is called when selecting subscribers to receive a message. | ||||
| func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers { | ||||
| 	return subs | ||||
| } | ||||
|  | ||||
| // OnUnsubscribe is called when a client unsubscribes from one or more filters. | ||||
| func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet { | ||||
| 	return pk | ||||
| } | ||||
|  | ||||
| // OnUnsubscribed is called when a client unsubscribes from one or more filters. | ||||
| func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnPublish is called when a client publishes a message. | ||||
| func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) { | ||||
| 	return pk, nil | ||||
| } | ||||
|  | ||||
| // OnPublished is called when a client has published a message to subscribers. | ||||
| func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnRetainMessage is called then a published message is retained. | ||||
| func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {} | ||||
|  | ||||
| // OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber. | ||||
| func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {} | ||||
|  | ||||
| // OnQosComplete is called when the Qos flow for a message has been completed. | ||||
| func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnQosDropped is called the Qos flow for a message expires. | ||||
| func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnWill is called when a client disconnects and publishes an LWT message. | ||||
| func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) { | ||||
| 	return will, nil | ||||
| } | ||||
|  | ||||
| // OnWillSent is called when an LWT message has been issued from a disconnecting client. | ||||
| func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {} | ||||
|  | ||||
| // OnClientExpired is called when a client session has expired. | ||||
| func (h *HookBase) OnClientExpired(cl *Client) {} | ||||
|  | ||||
| // OnRetainedExpired is called when a retained message for a topic has expired. | ||||
| func (h *HookBase) OnRetainedExpired(topic string) {} | ||||
|  | ||||
| // OnExpireInflights is called when the server issues a clear request for expired inflight messages. | ||||
| func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {} | ||||
|  | ||||
| // StoredClients returns all clients from a store. | ||||
| func (h *HookBase) StoredClients() (v []storage.Client, err error) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // StoredSubscriptions returns all subcriptions from a store. | ||||
| func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // StoredInflightMessages returns all inflight messages from a store. | ||||
| func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // StoredRetainedMessages returns all retained messages from a store. | ||||
| func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // StoredSysInfo returns a set of system info values. | ||||
| func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										41
									
								
								hooks/auth/allow_all.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								hooks/auth/allow_all.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| ) | ||||
|  | ||||
| // AllowHook is an authentication hook which allows connection access | ||||
| // for all users and read and write access to all topics. | ||||
| type AllowHook struct { | ||||
| 	mqtt.HookBase | ||||
| } | ||||
|  | ||||
| // ID returns the ID of the hook. | ||||
| func (h *AllowHook) ID() string { | ||||
| 	return "allow-all-auth" | ||||
| } | ||||
|  | ||||
| // Provides indicates which hook methods this hook provides. | ||||
| func (h *AllowHook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{ | ||||
| 		mqtt.OnConnectAuthenticate, | ||||
| 		mqtt.OnACLCheck, | ||||
| 	}, []byte{b}) | ||||
| } | ||||
|  | ||||
| // OnConnectAuthenticate returns true/allowed for all requests. | ||||
| func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // OnACLCheck returns true/allowed for all checks. | ||||
| func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { | ||||
| 	return true | ||||
| } | ||||
							
								
								
									
										35
									
								
								hooks/auth/allow_all_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								hooks/auth/allow_all_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestAllowAllID(t *testing.T) { | ||||
| 	h := new(AllowHook) | ||||
| 	require.Equal(t, "allow-all-auth", h.ID()) | ||||
| } | ||||
|  | ||||
| func TestAllowAllProvides(t *testing.T) { | ||||
| 	h := new(AllowHook) | ||||
| 	require.True(t, h.Provides(mqtt.OnACLCheck)) | ||||
| 	require.True(t, h.Provides(mqtt.OnConnectAuthenticate)) | ||||
| 	require.False(t, h.Provides(mqtt.OnPublished)) | ||||
| } | ||||
|  | ||||
| func TestAllowAllOnConnectAuthenticate(t *testing.T) { | ||||
| 	h := new(AllowHook) | ||||
| 	require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{})) | ||||
| } | ||||
|  | ||||
| func TestAllowAllOnACLCheck(t *testing.T) { | ||||
| 	h := new(AllowHook) | ||||
| 	require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true)) | ||||
| } | ||||
							
								
								
									
										107
									
								
								hooks/auth/auth.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								hooks/auth/auth.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,107 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| ) | ||||
|  | ||||
| // Options contains the configuration/rules data for the auth ledger. | ||||
| type Options struct { | ||||
| 	Data   []byte | ||||
| 	Ledger *Ledger | ||||
| } | ||||
|  | ||||
| // Hook is an authentication hook which implements an auth ledger. | ||||
| type Hook struct { | ||||
| 	mqtt.HookBase | ||||
| 	config *Options | ||||
| 	ledger *Ledger | ||||
| } | ||||
|  | ||||
| // ID returns the ID of the hook. | ||||
| func (h *Hook) ID() string { | ||||
| 	return "auth-ledger" | ||||
| } | ||||
|  | ||||
| // Provides indicates which hook methods this hook provides. | ||||
| func (h *Hook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{ | ||||
| 		mqtt.OnConnectAuthenticate, | ||||
| 		mqtt.OnACLCheck, | ||||
| 	}, []byte{b}) | ||||
| } | ||||
|  | ||||
| // Init configures the hook with the auth ledger to be used for checking. | ||||
| func (h *Hook) Init(config any) error { | ||||
| 	if _, ok := config.(*Options); !ok && config != nil { | ||||
| 		return mqtt.ErrInvalidConfigType | ||||
| 	} | ||||
|  | ||||
| 	if config == nil { | ||||
| 		config = new(Options) | ||||
| 	} | ||||
|  | ||||
| 	h.config = config.(*Options) | ||||
|  | ||||
| 	var err error | ||||
| 	if h.config.Ledger != nil { | ||||
| 		h.ledger = h.config.Ledger | ||||
| 	} else if len(h.config.Data) > 0 { | ||||
| 		h.ledger = new(Ledger) | ||||
| 		err = h.ledger.Unmarshal(h.config.Data) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if h.ledger == nil { | ||||
| 		h.ledger = &Ledger{ | ||||
| 			Auth: AuthRules{}, | ||||
| 			ACL:  ACLRules{}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	h.Log.Info(). | ||||
| 		Int("authentication", len(h.ledger.Auth)). | ||||
| 		Int("acl", len(h.ledger.ACL)). | ||||
| 		Msg("loaded auth rules") | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // OnConnectAuthenticate returns true if the connecting client has rules which provide access | ||||
| // in the auth ledger. | ||||
| func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { | ||||
| 	if _, ok := h.ledger.AuthOk(cl, pk); ok { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	h.Log.Info(). | ||||
| 		Str("username", string(pk.Connect.Username)). | ||||
| 		Str("remote", cl.Net.Remote). | ||||
| 		Msg("client failed authentication check") | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // OnACLCheck returns true if the connecting client has matching read or write access to subscribe | ||||
| // or publish to a given topic. | ||||
| func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { | ||||
| 	if _, ok := h.ledger.ACLOk(cl, topic, write); ok { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	h.Log.Debug(). | ||||
| 		Str("client", cl.ID). | ||||
| 		Str("username", string(cl.Properties.Username)). | ||||
| 		Str("topic", topic). | ||||
| 		Msg("client failed allowed ACL check") | ||||
|  | ||||
| 	return false | ||||
| } | ||||
							
								
								
									
										213
									
								
								hooks/auth/auth_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										213
									
								
								hooks/auth/auth_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,213 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/rs/zerolog" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) | ||||
|  | ||||
| // func teardown(t *testing.T, path string, h *Hook) { | ||||
| // 	h.Stop() | ||||
| // } | ||||
|  | ||||
| func TestBasicID(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	require.Equal(t, "auth-ledger", h.ID()) | ||||
| } | ||||
|  | ||||
| func TestBasicProvides(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	require.True(t, h.Provides(mqtt.OnACLCheck)) | ||||
| 	require.True(t, h.Provides(mqtt.OnConnectAuthenticate)) | ||||
| 	require.False(t, h.Provides(mqtt.OnPublish)) | ||||
| } | ||||
|  | ||||
| func TestBasicInitBadConfig(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(map[string]any{}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestBasicInitDefaultConfig(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestBasicInitWithLedgerPointer(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	ln := &Ledger{ | ||||
| 		Auth: []AuthRule{ | ||||
| 			{ | ||||
| 				Remote: "127.0.0.1", | ||||
| 				Allow:  true, | ||||
| 			}, | ||||
| 		}, | ||||
| 		ACL: []ACLRule{ | ||||
| 			{ | ||||
| 				Remote: "127.0.0.1", | ||||
| 				Filters: Filters{ | ||||
| 					"#": ReadWrite, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	err := h.Init(&Options{ | ||||
| 		Ledger: ln, | ||||
| 	}) | ||||
|  | ||||
| 	require.NoError(t, err) | ||||
| 	require.Same(t, ln, h.ledger) | ||||
| } | ||||
|  | ||||
| func TestBasicInitWithLedgerJSON(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	require.Nil(t, h.ledger) | ||||
| 	err := h.Init(&Options{ | ||||
| 		Data: ledgerJSON, | ||||
| 	}) | ||||
|  | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username) | ||||
| 	require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client) | ||||
| } | ||||
|  | ||||
| func TestBasicInitWithLedgerYAML(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	require.Nil(t, h.ledger) | ||||
| 	err := h.Init(&Options{ | ||||
| 		Data: ledgerYAML, | ||||
| 	}) | ||||
|  | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username) | ||||
| 	require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client) | ||||
| } | ||||
|  | ||||
| func TestBasicInitWithLedgerBadDAta(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	require.Nil(t, h.ledger) | ||||
| 	err := h.Init(&Options{ | ||||
| 		Data: []byte("fdsfdsafasd"), | ||||
| 	}) | ||||
|  | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestOnConnectAuthenticate(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	ln := new(Ledger) | ||||
| 	ln.Auth = checkLedger.Auth | ||||
| 	ln.ACL = checkLedger.ACL | ||||
| 	err := h.Init( | ||||
| 		&Options{ | ||||
| 			Ledger: ln, | ||||
| 		}, | ||||
| 	) | ||||
|  | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.True(t, h.OnConnectAuthenticate( | ||||
| 		&mqtt.Client{ | ||||
| 			Properties: mqtt.ClientProperties{ | ||||
| 				Username: []byte("mochi"), | ||||
| 			}, | ||||
| 		}, | ||||
| 		packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, | ||||
| 	)) | ||||
|  | ||||
| 	require.False(t, h.OnConnectAuthenticate( | ||||
| 		&mqtt.Client{ | ||||
| 			Properties: mqtt.ClientProperties{ | ||||
| 				Username: []byte("mochi"), | ||||
| 			}, | ||||
| 		}, | ||||
| 		packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, | ||||
| 	)) | ||||
|  | ||||
| 	require.False(t, h.OnConnectAuthenticate( | ||||
| 		&mqtt.Client{}, | ||||
| 		packets.Packet{}, | ||||
| 	)) | ||||
| } | ||||
|  | ||||
| func TestOnACL(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	ln := new(Ledger) | ||||
| 	ln.Auth = checkLedger.Auth | ||||
| 	ln.ACL = checkLedger.ACL | ||||
| 	err := h.Init( | ||||
| 		&Options{ | ||||
| 			Ledger: ln, | ||||
| 		}, | ||||
| 	) | ||||
|  | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.True(t, h.OnACLCheck( | ||||
| 		&mqtt.Client{ | ||||
| 			Properties: mqtt.ClientProperties{ | ||||
| 				Username: []byte("mochi"), | ||||
| 			}, | ||||
| 		}, | ||||
| 		"mochi/info", | ||||
| 		true, | ||||
| 	)) | ||||
|  | ||||
| 	require.False(t, h.OnACLCheck( | ||||
| 		&mqtt.Client{ | ||||
| 			Properties: mqtt.ClientProperties{ | ||||
| 				Username: []byte("mochi"), | ||||
| 			}, | ||||
| 		}, | ||||
| 		"d/j/f", | ||||
| 		true, | ||||
| 	)) | ||||
|  | ||||
| 	require.True(t, h.OnACLCheck( | ||||
| 		&mqtt.Client{ | ||||
| 			Properties: mqtt.ClientProperties{ | ||||
| 				Username: []byte("mochi"), | ||||
| 			}, | ||||
| 		}, | ||||
| 		"readonly", | ||||
| 		false, | ||||
| 	)) | ||||
|  | ||||
| 	require.False(t, h.OnACLCheck( | ||||
| 		&mqtt.Client{ | ||||
| 			Properties: mqtt.ClientProperties{ | ||||
| 				Username: []byte("mochi"), | ||||
| 			}, | ||||
| 		}, | ||||
| 		"readonly", | ||||
| 		true, | ||||
| 	)) | ||||
| } | ||||
							
								
								
									
										231
									
								
								hooks/auth/ledger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										231
									
								
								hooks/auth/ledger.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,231 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"gopkg.in/yaml.v3" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	Deny      Access = iota // user cannot access the topic | ||||
| 	ReadOnly                // user can only subscribe to the topic | ||||
| 	WriteOnly               // user can only publish to the topic | ||||
| 	ReadWrite               // user can both publish and subscribe to the topic | ||||
| ) | ||||
|  | ||||
| // Access determines the read/write privileges for an ACL rule. | ||||
| type Access byte | ||||
|  | ||||
| // Users contains a map of access rules for specific users, keyed on username. | ||||
| type Users map[string]UserRule | ||||
|  | ||||
| // UserRule defines a set of access rules for a specific user. | ||||
| type UserRule struct { | ||||
| 	Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user | ||||
| 	Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user | ||||
| 	ACL      Filters `json:"acl,omitempty" yaml:"acl,omitempty"`           // filters to match, if desired | ||||
| 	Disallow bool    `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user | ||||
| } | ||||
|  | ||||
| // AuthRules defines generic access rules applicable to all users. | ||||
| type AuthRules []AuthRule | ||||
|  | ||||
| type AuthRule struct { | ||||
| 	Client   RString `json:"client,omitempty" yaml:"client,omitempty"`     // the id of a connecting client | ||||
| 	Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user | ||||
| 	Remote   RString `json:"remote,omitempty" yaml:"remote,omitempty"`     // remote address or | ||||
| 	Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user | ||||
| 	Allow    bool    `json:"allow,omitempty" yaml:"allow,omitempty"`       // allow or disallow the users | ||||
| } | ||||
|  | ||||
| // ACLRules defines generic topic or filter access rules applicable to all users. | ||||
| type ACLRules []ACLRule | ||||
|  | ||||
| // ACLRule defines access rules for a specific topic or filter. | ||||
| type ACLRule struct { | ||||
| 	Client   RString `json:"client,omitempty" yaml:"client,omitempty"`     // the id of a connecting client | ||||
| 	Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user | ||||
| 	Remote   RString `json:"remote,omitempty" yaml:"remote,omitempty"`     // remote address or | ||||
| 	Filters  Filters `json:"filters,omitempty" yaml:"filters,omitempty"`   // filters to match | ||||
| } | ||||
|  | ||||
| // Filters is a map of Access rules keyed on filter. | ||||
| type Filters map[RString]Access | ||||
|  | ||||
| // RString is a rule value string. | ||||
| type RString string | ||||
|  | ||||
| // Matches returns true if the rule matches a given string. | ||||
| func (r RString) Matches(a string) bool { | ||||
| 	rr := string(r) | ||||
| 	if r == "" || r == "*" || a == rr { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	i := strings.Index(rr, "*") | ||||
| 	if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // FilterMatches returns true if a filter matches a topic rule. | ||||
| func (f RString) FilterMatches(a string) bool { | ||||
| 	_, ok := MatchTopic(string(f), a) | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| // MatchTopic checks if a given topic matches a filter, accounting for filter | ||||
| // wildcards. Eg. filter /a/b/+/c == topic a/b/d/c. | ||||
| func MatchTopic(filter string, topic string) (elements []string, matched bool) { | ||||
| 	filterParts := strings.Split(filter, "/") | ||||
| 	topicParts := strings.Split(topic, "/") | ||||
|  | ||||
| 	elements = make([]string, 0) | ||||
| 	for i := 0; i < len(filterParts); i++ { | ||||
| 		if i >= len(topicParts) { | ||||
| 			matched = false | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if filterParts[i] == "+" { | ||||
| 			elements = append(elements, topicParts[i]) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if filterParts[i] == "#" { | ||||
| 			matched = true | ||||
| 			elements = append(elements, strings.Join(topicParts[i:], "/")) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if filterParts[i] != topicParts[i] { | ||||
| 			matched = false | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return elements, true | ||||
| } | ||||
|  | ||||
| // Ledger is an auth ledger containing access rules for users and topics. | ||||
| type Ledger struct { | ||||
| 	sync.Mutex `json:"-" yaml:"-"` | ||||
| 	Users      Users     `json:"users" yaml:"users"` | ||||
| 	Auth       AuthRules `json:"auth" yaml:"auth"` | ||||
| 	ACL        ACLRules  `json:"acl" yaml:"acl"` | ||||
| } | ||||
|  | ||||
| // Update updates the internal values of the ledger. | ||||
| func (l *Ledger) Update(ln *Ledger) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
| 	l.Auth = ln.Auth | ||||
| 	l.ACL = ln.ACL | ||||
| } | ||||
|  | ||||
| // AuthOk returns true if the rules indicate the user is allowed to authenticate. | ||||
| func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) { | ||||
| 	// If the users map is set, always check for a predefined user first instead | ||||
| 	// of iterating through global rules. | ||||
| 	if l.Users != nil { | ||||
| 		if u, ok := l.Users[string(cl.Properties.Username)]; ok && | ||||
| 			u.Password != "" && | ||||
| 			u.Password == RString(pk.Connect.Password) { | ||||
| 			return 0, !u.Disallow | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// If there's no users map, or no user was found, attempt to find a matching | ||||
| 	// rule (which may also contain a user). | ||||
| 	for n, rule := range l.Auth { | ||||
| 		if rule.Client.Matches(cl.ID) && | ||||
| 			rule.Username.Matches(string(cl.Properties.Username)) && | ||||
| 			rule.Password.Matches(string(pk.Connect.Password)) && | ||||
| 			rule.Remote.Matches(cl.Net.Remote) { | ||||
| 			return n, rule.Allow | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return 0, false | ||||
| } | ||||
|  | ||||
| // ACLOk returns true if the rules indicate the user is allowed to read or write to | ||||
| // a specific filter or topic respectively, based on the write bool. | ||||
| func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) { | ||||
| 	// If the users map is set, always check for a predefined user first instead | ||||
| 	// of iterating through global rules. | ||||
| 	if l.Users != nil { | ||||
| 		if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 { | ||||
| 			for filter, access := range u.ACL { | ||||
| 				if filter.FilterMatches(topic) { | ||||
| 					if !write && (access == ReadOnly || access == ReadWrite) { | ||||
| 						return n, true | ||||
| 					} else if write && (access == WriteOnly || access == ReadWrite) { | ||||
| 						return n, true | ||||
| 					} else { | ||||
| 						return n, false | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for n, rule := range l.ACL { | ||||
| 		if rule.Client.Matches(cl.ID) && | ||||
| 			rule.Username.Matches(string(cl.Properties.Username)) && | ||||
| 			rule.Remote.Matches(cl.Net.Remote) { | ||||
| 			if len(rule.Filters) == 0 { | ||||
| 				return n, true | ||||
| 			} | ||||
|  | ||||
| 			for filter, access := range rule.Filters { | ||||
| 				if filter.FilterMatches(topic) { | ||||
| 					if !write && (access == ReadOnly || access == ReadWrite) { | ||||
| 						return n, true | ||||
| 					} else if write && (access == WriteOnly || access == ReadWrite) { | ||||
| 						return n, true | ||||
| 					} else { | ||||
| 						return n, false | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return 0, true | ||||
| } | ||||
|  | ||||
| // ToJSON encodes the values into a JSON string. | ||||
| func (l *Ledger) ToJSON() (data []byte, err error) { | ||||
| 	return json.Marshal(l) | ||||
| } | ||||
|  | ||||
| // ToYAML encodes the values into a YAML string. | ||||
| func (l *Ledger) ToYAML() (data []byte, err error) { | ||||
| 	return yaml.Marshal(l) | ||||
| } | ||||
|  | ||||
| // Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct. | ||||
| func (l *Ledger) Unmarshal(data []byte) error { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
| 	if len(data) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if data[0] == '{' { | ||||
| 		return json.Unmarshal(data, l) | ||||
| 	} | ||||
|  | ||||
| 	return yaml.Unmarshal(data, &l) | ||||
| } | ||||
							
								
								
									
										610
									
								
								hooks/auth/ledger_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										610
									
								
								hooks/auth/ledger_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,610 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	checkLedger = Ledger{ | ||||
| 		Users: Users{ // users are allowed by default | ||||
| 			"mochi-co": { | ||||
| 				Password: "melon", | ||||
| 				ACL: Filters{ | ||||
| 					"d/+/f":      Deny, | ||||
| 					"mochi-co/#": ReadWrite, | ||||
| 					"readonly":   ReadOnly, | ||||
| 				}, | ||||
| 			}, | ||||
| 			"suspended-username": { | ||||
| 				Password: "any", | ||||
| 				Disallow: true, | ||||
| 			}, | ||||
| 			"mochi": { // ACL only, will defer to AuthRules for authentication | ||||
| 				ACL: Filters{ | ||||
| 					"special/mochi": ReadOnly, | ||||
| 					"secret/mochi":  Deny, | ||||
| 					"ignored":       ReadWrite, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		Auth: AuthRules{ | ||||
| 			{Username: "banned-user"},                               // never allow specific username | ||||
| 			{Remote: "127.0.0.1", Allow: true},                      // always allow localhost | ||||
| 			{Remote: "123.123.123.123"},                             // disallow any from specific address | ||||
| 			{Username: "not-mochi", Remote: "111.144.155.166"},      // disallow specific username and address | ||||
| 			{Remote: "111.*", Allow: true},                          // allow any in wildcard (that isn't the above username) | ||||
| 			{Username: "mochi", Password: "melon", Allow: true},     // allow matching user/pass | ||||
| 			{Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map) | ||||
| 		}, | ||||
| 		ACL: ACLRules{ | ||||
| 			{ | ||||
| 				Username: "mochi", // allow matching user/pass | ||||
| 				Filters: Filters{ | ||||
| 					"a/b/c":     Deny, | ||||
| 					"d/+/f":     Deny, | ||||
| 					"mochi/#":   ReadWrite, | ||||
| 					"updates/#": WriteOnly, | ||||
| 					"readonly":  ReadOnly, | ||||
| 					"ignored":   Deny, | ||||
| 				}, | ||||
| 			}, | ||||
| 			{Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost | ||||
| 			{Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}},   // allow $SYS access to admin | ||||
| 			{Remote: "001.002.003.004"},                                 // Allow all with no filter | ||||
| 			{Filters: Filters{"$SYS/#": Deny}},                          // Deny $SYS access to all others | ||||
| 		}, | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| func TestRStringMatches(t *testing.T) { | ||||
| 	require.True(t, RString("*").Matches("any")) | ||||
| 	require.True(t, RString("*").Matches("")) | ||||
| 	require.True(t, RString("").Matches("any")) | ||||
| 	require.True(t, RString("").Matches("")) | ||||
| 	require.False(t, RString("no").Matches("any")) | ||||
| 	require.False(t, RString("no").Matches("")) | ||||
| } | ||||
|  | ||||
| func TestCanAuthenticate(t *testing.T) { | ||||
| 	tt := []struct { | ||||
| 		desc   string | ||||
| 		client *mqtt.Client | ||||
| 		pk     packets.Packet | ||||
| 		n      int | ||||
| 		ok     bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			desc: "allow all local 127.0.0.1", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 				Net: mqtt.ClientConnection{ | ||||
| 					Remote: "127.0.0.1", | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{Connect: packets.ConnectParams{}}, | ||||
| 			ok: true, | ||||
| 			n:  1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow username/password", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, | ||||
| 			ok: true, | ||||
| 			n:  5, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny username/password", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, | ||||
| 			ok: false, | ||||
| 			n:  0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow all local 127.0.0.1", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 				Net: mqtt.ClientConnection{ | ||||
| 					Remote: "127.0.0.1", | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, | ||||
| 			ok: true, | ||||
| 			n:  1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow username/password", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, | ||||
| 			ok: true, | ||||
| 			n:  5, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny username/password", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, | ||||
| 			ok: false, | ||||
| 			n:  0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny client from address", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("not-mochi"), | ||||
| 				}, | ||||
| 				Net: mqtt.ClientConnection{ | ||||
| 					Remote: "111.144.155.166", | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{}, | ||||
| 			ok: false, | ||||
| 			n:  3, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow remote wildcard", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 				Net: mqtt.ClientConnection{ | ||||
| 					Remote: "111.0.0.1", | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{}, | ||||
| 			ok: true, | ||||
| 			n:  4, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "never allow username", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("banned-user"), | ||||
| 				}, | ||||
| 				Net: mqtt.ClientConnection{ | ||||
| 					Remote: "127.0.0.1", | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{}, | ||||
| 			ok: false, | ||||
| 			n:  0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "matching user in users", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi-co"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, | ||||
| 			ok: true, | ||||
| 			n:  0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "never user in users", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("suspended-user"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}}, | ||||
| 			ok: false, | ||||
| 			n:  0, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, d := range tt { | ||||
| 		t.Run(d.desc, func(t *testing.T) { | ||||
| 			n, ok := checkLedger.AuthOk(d.client, d.pk) | ||||
| 			require.Equal(t, d.n, n) | ||||
| 			require.Equal(t, d.ok, ok) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCanACL(t *testing.T) { | ||||
| 	tt := []struct { | ||||
| 		client *mqtt.Client | ||||
| 		desc   string | ||||
| 		topic  string | ||||
| 		n      int | ||||
| 		write  bool | ||||
| 		ok     bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			desc:   "allow normal write on any other filter", | ||||
| 			client: &mqtt.Client{}, | ||||
| 			topic:  "default/acl/write/access", | ||||
| 			write:  true, | ||||
| 			ok:     true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "allow normal read on any other filter", | ||||
| 			client: &mqtt.Client{}, | ||||
| 			topic:  "default/acl/read/access", | ||||
| 			write:  false, | ||||
| 			ok:     true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny user on literal filter", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "a/b/c", | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny user on partial filter", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "d/j/f", | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow read/write to user path", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "mochi/read/write", | ||||
| 			write: true, | ||||
| 			ok:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny read on write-only path", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "updates/no/reading", | ||||
| 			write: false, | ||||
| 			ok:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny read on write-only path ext", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "updates/mochi", | ||||
| 			write: false, | ||||
| 			ok:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow read on not-acl path (no #)", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "updates", | ||||
| 			write: false, | ||||
| 			ok:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow write on write-only path", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "updates/mochi", | ||||
| 			write: true, | ||||
| 			ok:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny write on read-only path", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "readonly", | ||||
| 			write: true, | ||||
| 			ok:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow read on read-only path", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "readonly", | ||||
| 			write: false, | ||||
| 			ok:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow $sys access to localhost", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Net: mqtt.ClientConnection{ | ||||
| 					Remote: "localhost", | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "$SYS/test", | ||||
| 			write: false, | ||||
| 			ok:    true, | ||||
| 			n:     1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow $sys access to admin", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("admin"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "$SYS/test", | ||||
| 			write: false, | ||||
| 			ok:    true, | ||||
| 			n:     2, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "deny $sys access to all others", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "$SYS/test", | ||||
| 			write: false, | ||||
| 			ok:    false, | ||||
| 			n:     4, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "allow all with no filter", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Net: mqtt.ClientConnection{ | ||||
| 					Remote: "001.002.003.004", | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "any/path", | ||||
| 			write: true, | ||||
| 			ok:    true, | ||||
| 			n:     3, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "use users embedded acl deny", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "secret/mochi", | ||||
| 			write: true, | ||||
| 			ok:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "use users embedded acl any", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "any/mochi", | ||||
| 			write: true, | ||||
| 			ok:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "use users embedded acl write on read-only", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "special/mochi", | ||||
| 			write: true, | ||||
| 			ok:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "use users embedded acl read on read-only", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "special/mochi", | ||||
| 			write: false, | ||||
| 			ok:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "preference users embedded acl", | ||||
| 			client: &mqtt.Client{ | ||||
| 				Properties: mqtt.ClientProperties{ | ||||
| 					Username: []byte("mochi"), | ||||
| 				}, | ||||
| 			}, | ||||
| 			topic: "ignored", | ||||
| 			write: true, | ||||
| 			ok:    true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, d := range tt { | ||||
| 		t.Run(d.desc, func(t *testing.T) { | ||||
| 			n, ok := checkLedger.ACLOk(d.client, d.topic, d.write) | ||||
| 			require.Equal(t, d.n, n) | ||||
| 			require.Equal(t, d.ok, ok) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMatchTopic(t *testing.T) { | ||||
| 	el, matched := MatchTopic("a/+/c/+", "a/b/c/d") | ||||
| 	require.True(t, matched) | ||||
| 	require.Equal(t, []string{"b", "d"}, el) | ||||
|  | ||||
| 	el, matched = MatchTopic("a/+/+/+", "a/b/c/d") | ||||
| 	require.True(t, matched) | ||||
| 	require.Equal(t, []string{"b", "c", "d"}, el) | ||||
|  | ||||
| 	el, matched = MatchTopic("stuff/#", "stuff/things/yeah") | ||||
| 	require.True(t, matched) | ||||
| 	require.Equal(t, []string{"things/yeah"}, el) | ||||
|  | ||||
| 	el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds") | ||||
| 	require.True(t, matched) | ||||
| 	require.Equal(t, []string{"b", "c/d/as/dds"}, el) | ||||
|  | ||||
| 	el, matched = MatchTopic("test", "test") | ||||
| 	require.True(t, matched) | ||||
| 	require.Equal(t, make([]string, 0), el) | ||||
|  | ||||
| 	el, matched = MatchTopic("things/stuff//", "things/stuff/") | ||||
| 	require.False(t, matched) | ||||
| 	require.Equal(t, make([]string, 0), el) | ||||
|  | ||||
| 	el, matched = MatchTopic("t", "t2") | ||||
| 	require.False(t, matched) | ||||
| 	require.Equal(t, make([]string, 0), el) | ||||
|  | ||||
| 	el, matched = MatchTopic(" ", "  ") | ||||
| 	require.False(t, matched) | ||||
| 	require.Equal(t, make([]string, 0), el) | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	ledgerStruct = Ledger{ | ||||
| 		Users: Users{ | ||||
| 			"mochi": { | ||||
| 				Password: "peach", | ||||
| 				ACL: Filters{ | ||||
| 					"readonly": ReadOnly, | ||||
| 					"deny":     Deny, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		Auth: AuthRules{ | ||||
| 			{ | ||||
| 				Client:   "*", | ||||
| 				Username: "mochi-co", | ||||
| 				Password: "melon", | ||||
| 				Remote:   "192.168.1.*", | ||||
| 				Allow:    true, | ||||
| 			}, | ||||
| 		}, | ||||
| 		ACL: ACLRules{ | ||||
| 			{ | ||||
| 				Client:   "*", | ||||
| 				Username: "mochi-co", | ||||
| 				Remote:   "127.*", | ||||
| 				Filters: Filters{ | ||||
| 					"readonly":  ReadOnly, | ||||
| 					"writeonly": WriteOnly, | ||||
| 					"readwrite": ReadWrite, | ||||
| 					"deny":      Deny, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`) | ||||
| 	ledgerYAML = []byte(`users: | ||||
|     mochi: | ||||
|         password: peach | ||||
|         acl: | ||||
|             deny: 0 | ||||
|             readonly: 1 | ||||
| auth: | ||||
|     - client: '*' | ||||
|       username: mochi-co | ||||
|       remote: 192.168.1.* | ||||
|       password: melon | ||||
|       allow: true | ||||
| acl: | ||||
|     - client: '*' | ||||
|       username: mochi-co | ||||
|       remote: 127.* | ||||
|       filters: | ||||
|         deny: 0 | ||||
|         readonly: 1 | ||||
|         readwrite: 3 | ||||
|         writeonly: 2 | ||||
| `) | ||||
| ) | ||||
|  | ||||
| func TestLedgerUpdate(t *testing.T) { | ||||
| 	old := &Ledger{ | ||||
| 		Auth: AuthRules{ | ||||
| 			{Remote: "127.0.0.1", Allow: true}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	new := &Ledger{ | ||||
| 		Auth: AuthRules{ | ||||
| 			{Remote: "127.0.0.1", Allow: true}, | ||||
| 			{Remote: "192.168.*", Allow: true}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	old.Update(new) | ||||
| 	require.Len(t, old.Auth, 2) | ||||
| 	require.Equal(t, RString("192.168.*"), old.Auth[1].Remote) | ||||
| 	require.NotSame(t, new, old) | ||||
| } | ||||
|  | ||||
| func TestLedgerToJSON(t *testing.T) { | ||||
| 	data, err := ledgerStruct.ToJSON() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, ledgerJSON, data) | ||||
| } | ||||
|  | ||||
| func TestLedgerToYAML(t *testing.T) { | ||||
| 	data, err := ledgerStruct.ToYAML() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, ledgerYAML, data) | ||||
| } | ||||
|  | ||||
| func TestLedgerUnmarshalFromYAML(t *testing.T) { | ||||
| 	l := new(Ledger) | ||||
| 	err := l.Unmarshal(ledgerYAML) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, &ledgerStruct, l) | ||||
| 	require.NotSame(t, l, &ledgerStruct) | ||||
| } | ||||
|  | ||||
| func TestLedgerUnmarshalFromJSON(t *testing.T) { | ||||
| 	l := new(Ledger) | ||||
| 	err := l.Unmarshal(ledgerJSON) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, &ledgerStruct, l) | ||||
| 	require.NotSame(t, l, &ledgerStruct) | ||||
| } | ||||
|  | ||||
| func TestLedgerUnmarshalNil(t *testing.T) { | ||||
| 	l := new(Ledger) | ||||
| 	err := l.Unmarshal([]byte{}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, new(Ledger), l) | ||||
| } | ||||
							
								
								
									
										250
									
								
								hooks/debug/debug.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										250
									
								
								hooks/debug/debug.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,250 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package debug | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
|  | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| // Options contains configuration settings for the debug output. | ||||
| type Options struct { | ||||
| 	ShowPacketData bool // include decoded packet data (default false) | ||||
| 	ShowPings      bool // show ping requests and responses (default false) | ||||
| 	ShowPasswords  bool // show connecting user passwords (default false) | ||||
| } | ||||
|  | ||||
| // Hook is a debugging hook which logs additional low-level information from the server. | ||||
| type Hook struct { | ||||
| 	mqtt.HookBase | ||||
| 	config *Options | ||||
| 	Log    *zerolog.Logger | ||||
| } | ||||
|  | ||||
| // ID returns the ID of the hook. | ||||
| func (h *Hook) ID() string { | ||||
| 	return "debug" | ||||
| } | ||||
|  | ||||
| // Provides indicates that this hook provides all methods. | ||||
| func (h *Hook) Provides(b byte) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // Init is called when the hook is initialized. | ||||
| func (h *Hook) Init(config any) error { | ||||
| 	if _, ok := config.(*Options); !ok && config != nil { | ||||
| 		return mqtt.ErrInvalidConfigType | ||||
| 	} | ||||
|  | ||||
| 	if config == nil { | ||||
| 		config = new(Options) | ||||
| 	} | ||||
|  | ||||
| 	h.config = config.(*Options) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SetOpts is called when the hook receives inheritable server parameters. | ||||
| func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) { | ||||
| 	h.Log = l | ||||
| 	h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send() | ||||
| } | ||||
|  | ||||
| // Stop is called when the hook is stopped. | ||||
| func (h *Hook) Stop() error { | ||||
| 	h.Log.Debug().Str("method", "Stop").Send() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // OnStarted is called when the server starts. | ||||
| func (h *Hook) OnStarted() { | ||||
| 	h.Log.Debug().Str("method", "OnStarted").Send() | ||||
| } | ||||
|  | ||||
| // OnStopped is called when the server stops. | ||||
| func (h *Hook) OnStopped() { | ||||
| 	h.Log.Debug().Str("method", "OnStopped").Send() | ||||
| } | ||||
|  | ||||
| // OnPacketRead is called when a new packet is received from a client. | ||||
| func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { | ||||
| 	if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings { | ||||
| 		return pk, nil | ||||
| 	} | ||||
|  | ||||
| 	h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID) | ||||
|  | ||||
| 	return pk, nil | ||||
| } | ||||
|  | ||||
| // OnPacketSent is called when a packet is sent to a client. | ||||
| func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) { | ||||
| 	if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID) | ||||
| } | ||||
|  | ||||
| // OnRetainMessage is called when a published message is retained (or retain deleted/modified). | ||||
| func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { | ||||
| 	h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic") | ||||
| } | ||||
|  | ||||
| // OnQosPublish is called when a publish packet with Qos is issued to a subscriber. | ||||
| func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { | ||||
| 	h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out") | ||||
| } | ||||
|  | ||||
| // OnQosComplete is called when the Qos flow for a message has been completed. | ||||
| func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete") | ||||
| } | ||||
|  | ||||
| // OnQosDropped is called the Qos flow for a message expires. | ||||
| func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped") | ||||
| } | ||||
|  | ||||
| // OnLWTSent is called when a will message has been issued from a disconnecting client. | ||||
| func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client") | ||||
| } | ||||
|  | ||||
| // OnRetainedExpired is called when the server clears expired retained messages. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired") | ||||
| } | ||||
|  | ||||
| // OnClientExpired is called when the server clears an expired client. | ||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||
| 	h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired") | ||||
| } | ||||
|  | ||||
| // StoredClients is called when the server restores clients from a store. | ||||
| func (h *Hook) StoredClients() (v []storage.Client, err error) { | ||||
| 	h.Log.Debug(). | ||||
| 		Str("method", "StoredClients"). | ||||
| 		Send() | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredClients is called when the server restores subscriptions from a store. | ||||
| func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| 	h.Log.Debug(). | ||||
| 		Str("method", "StoredSubscriptions"). | ||||
| 		Send() | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredClients is called when the server restores retained messages from a store. | ||||
| func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
| 	h.Log.Debug(). | ||||
| 		Str("method", "StoredRetainedMessages"). | ||||
| 		Send() | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredClients is called when the server restores inflight messages from a store. | ||||
| func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| 	h.Log.Debug(). | ||||
| 		Str("method", "StoredInflightMessages"). | ||||
| 		Send() | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredClients is called when the server restores system info from a store. | ||||
| func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| 	h.Log.Debug(). | ||||
| 		Str("method", "StoredClients"). | ||||
| 		Send() | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // packetMeta adds additional type-specific metadata to the debug logs. | ||||
| func (h *Hook) packetMeta(pk packets.Packet) map[string]any { | ||||
| 	m := map[string]any{} | ||||
| 	switch pk.FixedHeader.Type { | ||||
| 	case packets.Connect: | ||||
| 		m["id"] = pk.Connect.ClientIdentifier | ||||
| 		m["clean"] = pk.Connect.Clean | ||||
| 		m["keepalive"] = pk.Connect.Keepalive | ||||
| 		m["version"] = pk.ProtocolVersion | ||||
| 		m["username"] = string(pk.Connect.Username) | ||||
| 		if h.config.ShowPasswords { | ||||
| 			m["password"] = string(pk.Connect.Password) | ||||
| 		} | ||||
| 		if pk.Connect.WillFlag { | ||||
| 			m["will_topic"] = pk.Connect.WillTopic | ||||
| 			m["will_payload"] = string(pk.Connect.WillPayload) | ||||
| 		} | ||||
| 	case packets.Publish: | ||||
| 		m["topic"] = pk.TopicName | ||||
| 		m["payload"] = string(pk.Payload) | ||||
| 		m["raw"] = pk.Payload | ||||
| 		m["qos"] = pk.FixedHeader.Qos | ||||
| 		m["id"] = pk.PacketID | ||||
| 	case packets.Connack: | ||||
| 		fallthrough | ||||
| 	case packets.Disconnect: | ||||
| 		fallthrough | ||||
| 	case packets.Puback: | ||||
| 		fallthrough | ||||
| 	case packets.Pubrec: | ||||
| 		fallthrough | ||||
| 	case packets.Pubrel: | ||||
| 		fallthrough | ||||
| 	case packets.Pubcomp: | ||||
| 		m["id"] = pk.PacketID | ||||
| 		m["reason"] = int(pk.ReasonCode) | ||||
| 		if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 { | ||||
| 			m["reason_string"] = pk.Properties.ReasonString | ||||
| 		} | ||||
| 	case packets.Subscribe: | ||||
| 		f := map[string]int{} | ||||
| 		ids := map[string]int{} | ||||
| 		for _, v := range pk.Filters { | ||||
| 			f[v.Filter] = int(v.Qos) | ||||
| 			ids[v.Filter] = v.Identifier | ||||
| 		} | ||||
| 		m["filters"] = f | ||||
| 		m["subids"] = f | ||||
|  | ||||
| 	case packets.Unsubscribe: | ||||
| 		f := []string{} | ||||
| 		for _, v := range pk.Filters { | ||||
| 			f = append(f, v.Filter) | ||||
| 		} | ||||
| 		m["filters"] = f | ||||
| 	case packets.Suback: | ||||
| 		fallthrough | ||||
| 	case packets.Unsuback: | ||||
| 		r := []int{} | ||||
| 		for _, v := range pk.ReasonCodes { | ||||
| 			r = append(r, int(v)) | ||||
| 		} | ||||
| 		m["reasons"] = r | ||||
| 	case packets.Auth: | ||||
| 		// tbd | ||||
| 	} | ||||
|  | ||||
| 	if h.config.ShowPacketData { | ||||
| 		m["packet"] = pk | ||||
| 	} | ||||
|  | ||||
| 	return m | ||||
| } | ||||
							
								
								
									
										484
									
								
								hooks/storage/badger/badger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										484
									
								
								hooks/storage/badger/badger.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,484 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package badger | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	"github.com/timshannon/badgerhold" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	// defaultDbFile is the default file path for the badger db file. | ||||
| 	defaultDbFile = ".badger" | ||||
| ) | ||||
|  | ||||
| // clientKey returns a primary key for a client. | ||||
| func clientKey(cl *mqtt.Client) string { | ||||
| 	return cl.ID | ||||
| } | ||||
|  | ||||
| // subscriptionKey returns a primary key for a subscription. | ||||
| func subscriptionKey(cl *mqtt.Client, filter string) string { | ||||
| 	return storage.SubscriptionKey + "_" + cl.ID + ":" + filter | ||||
| } | ||||
|  | ||||
| // retainedKey returns a primary key for a retained message. | ||||
| func retainedKey(topic string) string { | ||||
| 	return storage.RetainedKey + "_" + topic | ||||
| } | ||||
|  | ||||
| // inflightKey returns a primary key for an inflight message. | ||||
| func inflightKey(cl *mqtt.Client, pk packets.Packet) string { | ||||
| 	return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID() | ||||
| } | ||||
|  | ||||
| // sysInfoKey returns a primary key for system info. | ||||
| func sysInfoKey() string { | ||||
| 	return storage.SysInfoKey | ||||
| } | ||||
|  | ||||
| // Options contains configuration settings for the BadgerDB instance. | ||||
| type Options struct { | ||||
| 	Options *badgerhold.Options | ||||
| 	Path    string | ||||
| } | ||||
|  | ||||
| // Hook is a persistent storage hook based using BadgerDB file store as a backend. | ||||
| type Hook struct { | ||||
| 	mqtt.HookBase | ||||
| 	config *Options          // options for configuring the BadgerDB instance. | ||||
| 	db     *badgerhold.Store // the BadgerDB instance. | ||||
| } | ||||
|  | ||||
| // ID returns the id of the hook. | ||||
| func (h *Hook) ID() string { | ||||
| 	return "badger-db" | ||||
| } | ||||
|  | ||||
| // Provides indicates which hook methods this hook provides. | ||||
| func (h *Hook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{ | ||||
| 		mqtt.OnSessionEstablished, | ||||
| 		mqtt.OnDisconnect, | ||||
| 		mqtt.OnSubscribed, | ||||
| 		mqtt.OnUnsubscribed, | ||||
| 		mqtt.OnRetainMessage, | ||||
| 		mqtt.OnWillSent, | ||||
| 		mqtt.OnQosPublish, | ||||
| 		mqtt.OnQosComplete, | ||||
| 		mqtt.OnQosDropped, | ||||
| 		mqtt.OnSysInfoTick, | ||||
| 		mqtt.OnClientExpired, | ||||
| 		mqtt.OnRetainedExpired, | ||||
| 		mqtt.OnExpireInflights, | ||||
| 		mqtt.StoredClients, | ||||
| 		mqtt.StoredInflightMessages, | ||||
| 		mqtt.StoredRetainedMessages, | ||||
| 		mqtt.StoredSubscriptions, | ||||
| 		mqtt.StoredSysInfo, | ||||
| 	}, []byte{b}) | ||||
| } | ||||
|  | ||||
| // Init initializes and connects to the badger instance. | ||||
| func (h *Hook) Init(config any) error { | ||||
| 	if _, ok := config.(*Options); !ok && config != nil { | ||||
| 		return mqtt.ErrInvalidConfigType | ||||
| 	} | ||||
|  | ||||
| 	if config == nil { | ||||
| 		config = new(Options) | ||||
| 	} | ||||
|  | ||||
| 	h.config = config.(*Options) | ||||
| 	if h.config.Path == "" { | ||||
| 		h.config.Path = defaultDbFile | ||||
| 	} | ||||
|  | ||||
| 	options := badgerhold.DefaultOptions | ||||
| 	options.Dir = h.config.Path | ||||
| 	options.ValueDir = h.config.Path | ||||
| 	options.Logger = h | ||||
|  | ||||
| 	var err error | ||||
| 	h.db, err = badgerhold.Open(options) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Stop closes the badger instance. | ||||
| func (h *Hook) Stop() error { | ||||
| 	return h.db.Close() | ||||
| } | ||||
|  | ||||
| // OnSessionEstablished adds a client to the store when their session is established. | ||||
| func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.updateClient(cl) | ||||
| } | ||||
|  | ||||
| // OnWillSent is called when a client sends a will message and the will message is removed | ||||
| // from the client record. | ||||
| func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.updateClient(cl) | ||||
| } | ||||
|  | ||||
| // updateClient writes the client data to the store. | ||||
| func (h *Hook) updateClient(cl *mqtt.Client) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := cl.Properties.Props.Copy(false) | ||||
| 	in := &storage.Client{ | ||||
| 		ID:              clientKey(cl), | ||||
| 		T:               storage.ClientKey, | ||||
| 		Remote:          cl.Net.Remote, | ||||
| 		Listener:        cl.Net.Listener, | ||||
| 		Username:        cl.Properties.Username, | ||||
| 		Clean:           cl.Properties.Clean, | ||||
| 		ProtocolVersion: cl.Properties.ProtocolVersion, | ||||
| 		Properties: storage.ClientProperties{ | ||||
| 			SessionExpiryInterval: props.SessionExpiryInterval, | ||||
| 			AuthenticationMethod:  props.AuthenticationMethod, | ||||
| 			AuthenticationData:    props.AuthenticationData, | ||||
| 			RequestProblemInfo:    props.RequestProblemInfo, | ||||
| 			RequestResponseInfo:   props.RequestResponseInfo, | ||||
| 			ReceiveMaximum:        props.ReceiveMaximum, | ||||
| 			TopicAliasMaximum:     props.TopicAliasMaximum, | ||||
| 			User:                  props.User, | ||||
| 			MaximumPacketSize:     props.MaximumPacketSize, | ||||
| 		}, | ||||
| 		Will: storage.ClientWill(cl.Properties.Will), | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Upsert(in.ID, in) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnDisconnect removes a client from the store if their session has expired. | ||||
| func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h.updateClient(cl) | ||||
|  | ||||
| 	if !expire { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Delete(clientKey(cl), new(storage.Client)) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnSubscribed adds one or more client subscriptions to the store. | ||||
| func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var in *storage.Subscription | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		in = &storage.Subscription{ | ||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:      storage.SubscriptionKey, | ||||
| 			Client: cl.ID, | ||||
| 			Filter: pk.Filters[i].Filter, | ||||
| 			Qos:    reasonCodes[i], | ||||
| 		} | ||||
|  | ||||
| 		err := h.db.Upsert(in.ID, in) | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnUnsubscribed removes one or more client subscriptions from the store. | ||||
| func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription)) | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainMessage adds a retained message for a topic to the store. | ||||
| func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if r == -1 { | ||||
| 		err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message)) | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data") | ||||
| 		} | ||||
|  | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := pk.Properties.Copy(false) | ||||
| 	in := &storage.Message{ | ||||
| 		ID:          retainedKey(pk.TopicName), | ||||
| 		T:           storage.RetainedKey, | ||||
| 		FixedHeader: pk.FixedHeader, | ||||
| 		TopicName:   pk.TopicName, | ||||
| 		Payload:     pk.Payload, | ||||
| 		Created:     pk.Created, | ||||
| 		Origin:      pk.Origin, | ||||
| 		Properties: storage.MessageProperties{ | ||||
| 			PayloadFormat:          props.PayloadFormat, | ||||
| 			MessageExpiryInterval:  props.MessageExpiryInterval, | ||||
| 			ContentType:            props.ContentType, | ||||
| 			ResponseTopic:          props.ResponseTopic, | ||||
| 			CorrelationData:        props.CorrelationData, | ||||
| 			SubscriptionIdentifier: props.SubscriptionIdentifier, | ||||
| 			TopicAlias:             props.TopicAlias, | ||||
| 			User:                   props.User, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Upsert(in.ID, in) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosPublish adds or updates an inflight message in the store. | ||||
| func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := pk.Properties.Copy(false) | ||||
| 	in := &storage.Message{ | ||||
| 		ID:          inflightKey(cl, pk), | ||||
| 		T:           storage.InflightKey, | ||||
| 		Origin:      pk.Origin, | ||||
| 		PacketID:    pk.PacketID, | ||||
| 		FixedHeader: pk.FixedHeader, | ||||
| 		TopicName:   pk.TopicName, | ||||
| 		Payload:     pk.Payload, | ||||
| 		Sent:        sent, | ||||
| 		Created:     pk.Created, | ||||
| 		Properties: storage.MessageProperties{ | ||||
| 			PayloadFormat:          props.PayloadFormat, | ||||
| 			MessageExpiryInterval:  props.MessageExpiryInterval, | ||||
| 			ContentType:            props.ContentType, | ||||
| 			ResponseTopic:          props.ResponseTopic, | ||||
| 			CorrelationData:        props.CorrelationData, | ||||
| 			SubscriptionIdentifier: props.SubscriptionIdentifier, | ||||
| 			TopicAlias:             props.TopicAlias, | ||||
| 			User:                   props.User, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Upsert(in.ID, in) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosComplete removes a resolved inflight message from the store. | ||||
| func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Delete(inflightKey(cl, pk), new(storage.Message)) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosDropped removes a dropped inflight message from the store. | ||||
| func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 	} | ||||
|  | ||||
| 	h.OnQosComplete(cl, pk) | ||||
| } | ||||
|  | ||||
| // OnSysInfoTick stores the latest system info in the store. | ||||
| func (h *Hook) OnSysInfoTick(sys *system.Info) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	in := &storage.SystemInfo{ | ||||
| 		ID:   sysInfoKey(), | ||||
| 		T:    storage.SysInfoKey, | ||||
| 		Info: *sys, | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Upsert(in.ID, in) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnExpireInflights removes all inflight messages which have passed the provided expiry time. | ||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var v []storage.Message | ||||
| 	err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey)) | ||||
| 	if err != nil && !errors.Is(err, badgerhold.ErrNotFound) { | ||||
| 		h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, m := range v { | ||||
| 		if m.Created < expiry || m.Created == 0 { | ||||
| 			err := h.db.Delete(m.ID, new(storage.Message)) | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	err := h.db.Delete(retainedKey(filter), new(storage.Message)) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnClientExpired deleted expired clients from the store. | ||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||
| 	err := h.db.Delete(clientKey(cl), new(storage.Client)) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // StoredClients returns all stored clients from the store. | ||||
| func (h *Hook) StoredClients() (v []storage.Client, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.ClientKey)) | ||||
| 	if err != nil && !errors.Is(err, badgerhold.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredSubscriptions returns all stored subscriptions from the store. | ||||
| func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.SubscriptionKey)) | ||||
| 	if err != nil && !errors.Is(err, badgerhold.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredRetainedMessages returns all stored retained messages from the store. | ||||
| func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.RetainedKey)) | ||||
| 	if err != nil && !errors.Is(err, badgerhold.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredInflightMessages returns all stored inflight messages from the store. | ||||
| func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey)) | ||||
| 	if err != nil && !errors.Is(err, badgerhold.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredSysInfo returns the system info from the store. | ||||
| func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Get(storage.SysInfoKey, &v) | ||||
| 	if err != nil && !errors.Is(err, badgerhold.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // Errorf satisfies the badger interface for an error logger. | ||||
| func (h *Hook) Errorf(m string, v ...interface{}) { | ||||
| 	h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) | ||||
| } | ||||
|  | ||||
| // Warningf satisfies the badger interface for a warning logger. | ||||
| func (h *Hook) Warningf(m string, v ...interface{}) { | ||||
| 	h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) | ||||
| } | ||||
|  | ||||
| // Infof satisfies the badger interface for an info logger. | ||||
| func (h *Hook) Infof(m string, v ...interface{}) { | ||||
| 	h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) | ||||
| } | ||||
|  | ||||
| // Debugf satisfies the badger interface for a debug logger. | ||||
| func (h *Hook) Debugf(m string, v ...interface{}) { | ||||
| 	h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...) | ||||
| } | ||||
							
								
								
									
										695
									
								
								hooks/storage/badger/badger_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										695
									
								
								hooks/storage/badger/badger_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,695 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package badger | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/asdine/storm/v3" | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
| 	"github.com/rs/zerolog" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"github.com/timshannon/badgerhold" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) | ||||
|  | ||||
| 	client = &mqtt.Client{ | ||||
| 		ID: "test", | ||||
| 		Net: mqtt.ClientConnection{ | ||||
| 			Remote:   "test.addr", | ||||
| 			Listener: "listener", | ||||
| 		}, | ||||
| 		Properties: mqtt.ClientProperties{ | ||||
| 			Username: []byte("username"), | ||||
| 			Clean:    false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}} | ||||
| ) | ||||
|  | ||||
| func teardown(t *testing.T, path string, h *Hook) { | ||||
| 	h.Stop() | ||||
| 	h.db.Badger().Close() | ||||
| 	err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1)) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientKey(t *testing.T) { | ||||
| 	k := clientKey(&mqtt.Client{ID: "cl1"}) | ||||
| 	require.Equal(t, "cl1", k) | ||||
| } | ||||
|  | ||||
| func TestSubscriptionKey(t *testing.T) { | ||||
| 	k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c") | ||||
| 	require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k) | ||||
| } | ||||
|  | ||||
| func TestRetainedKey(t *testing.T) { | ||||
| 	k := retainedKey("a/b/c") | ||||
| 	require.Equal(t, storage.RetainedKey+"_a/b/c", k) | ||||
| } | ||||
|  | ||||
| func TestInflightKey(t *testing.T) { | ||||
| 	k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1}) | ||||
| 	require.Equal(t, storage.InflightKey+"_cl1:1", k) | ||||
| } | ||||
|  | ||||
| func TestSysInfoKey(t *testing.T) { | ||||
| 	require.Equal(t, storage.SysInfoKey, sysInfoKey()) | ||||
| } | ||||
|  | ||||
| func TestID(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	require.Equal(t, "badger-db", h.ID()) | ||||
| } | ||||
|  | ||||
| func TestProvides(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	require.True(t, h.Provides(mqtt.OnSessionEstablished)) | ||||
| 	require.True(t, h.Provides(mqtt.OnDisconnect)) | ||||
| 	require.True(t, h.Provides(mqtt.OnSubscribed)) | ||||
| 	require.True(t, h.Provides(mqtt.OnUnsubscribed)) | ||||
| 	require.True(t, h.Provides(mqtt.OnRetainMessage)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosPublish)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosComplete)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosDropped)) | ||||
| 	require.True(t, h.Provides(mqtt.OnSysInfoTick)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredClients)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredInflightMessages)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredRetainedMessages)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredSubscriptions)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredSysInfo)) | ||||
| 	require.False(t, h.Provides(mqtt.OnACLCheck)) | ||||
| 	require.False(t, h.Provides(mqtt.OnConnectAuthenticate)) | ||||
| } | ||||
|  | ||||
| func TestInitBadConfig(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(map[string]any{}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestInitUseDefaults(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	require.Equal(t, defaultDbFile, h.config.Path) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	err = h.db.Get(clientKey(client), r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, client.ID, r.ID) | ||||
| 	require.Equal(t, client.Properties.Username, r.Username) | ||||
| 	require.Equal(t, client.Properties.Clean, r.Clean) | ||||
| 	require.Equal(t, client.Net.Remote, r.Remote) | ||||
| 	require.Equal(t, client.Net.Listener, r.Listener) | ||||
| 	require.NotSame(t, client, r) | ||||
|  | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| 	r2 := new(storage.Client) | ||||
| 	err = h.db.Get(clientKey(client), r2) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, client.ID, r.ID) | ||||
|  | ||||
| 	h.OnDisconnect(client, nil, true) | ||||
| 	r3 := new(storage.Client) | ||||
| 	err = h.db.Get(clientKey(client), r3) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, badgerhold.ErrNotFound, err) | ||||
| 	require.Empty(t, r3.ID) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpired(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	cl := &mqtt.Client{ID: "cl1"} | ||||
| 	clientKey := clientKey(cl) | ||||
|  | ||||
| 	err = h.db.Upsert(clientKey, &storage.Client{ID: cl.ID}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	err = h.db.Get(clientKey, r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, cl.ID, r.ID) | ||||
|  | ||||
| 	h.OnClientExpired(cl) | ||||
| 	err = h.db.Get(clientKey, r) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, badgerhold.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnWillSent(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	c1 := client | ||||
| 	c1.Properties.Will.Flag = 1 | ||||
| 	h.OnWillSent(c1, packets.Packet{}) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	err = h.db.Get(clientKey(client), r) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.Equal(t, uint32(1), r.Will.Flag) | ||||
| 	require.NotSame(t, client, r) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
| 	r := new(storage.Subscription) | ||||
|  | ||||
| 	err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, client.ID, r.Client) | ||||
| 	require.Equal(t, pkf.Filters[0].Filter, r.Filter) | ||||
| 	require.Equal(t, byte(0), r.Qos) | ||||
|  | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| 	err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, badgerhold.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
| } | ||||
|  | ||||
| func TestOnUnsubscribedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| } | ||||
|  | ||||
| func TestOnUnsubscribedClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageThenUnset(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	pk := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Retain: true, | ||||
| 		}, | ||||
| 		Payload:   []byte("hello"), | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	h.OnRetainMessage(client, pk, 1) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	err = h.db.Get(retainedKey(pk.TopicName), r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, pk.TopicName, r.TopicName) | ||||
| 	require.Equal(t, pk.Payload, r.Payload) | ||||
|  | ||||
| 	h.OnRetainMessage(client, pk, -1) | ||||
| 	err = h.db.Get(retainedKey(pk.TopicName), r) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, badgerhold.ErrNotFound) | ||||
|  | ||||
| 	// coverage: delete deleted | ||||
| 	h.OnRetainMessage(client, pk, -1) | ||||
| 	err = h.db.Get(retainedKey(pk.TopicName), r) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, badgerhold.ErrNotFound) | ||||
| } | ||||
|  | ||||
| func TestOnRetainedExpired(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	m := &storage.Message{ | ||||
| 		ID:        retainedKey("a/b/c"), | ||||
| 		T:         storage.RetainedKey, | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Upsert(m.ID, m) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	err = h.db.Get(m.ID, r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, m.TopicName, r.TopicName) | ||||
|  | ||||
| 	h.OnRetainedExpired(m.TopicName) | ||||
| 	err = h.db.Get(m.ID, r) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, badgerhold.ErrNotFound) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnRetainMessage(client, packets.Packet{}, 0) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnRetainMessage(client, packets.Packet{}, 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishThenQOSComplete(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	pk := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Retain: true, | ||||
| 			Qos:    2, | ||||
| 		}, | ||||
| 		Payload:   []byte("hello"), | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	h.OnQosPublish(client, pk, time.Now().Unix(), 0) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	err = h.db.Get(inflightKey(client, pk), r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, pk.TopicName, r.TopicName) | ||||
| 	require.Equal(t, pk.Payload, r.Payload) | ||||
|  | ||||
| 	// ensure dates are properly saved | ||||
| 	require.True(t, r.Sent > 0) | ||||
| 	require.True(t, time.Now().Unix()-1 < r.Sent) | ||||
|  | ||||
| 	// OnQosDropped is a passthrough to OnQosComplete here | ||||
| 	h.OnQosDropped(client, pk) | ||||
| 	err = h.db.Get(inflightKey(client, pk), r) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, badgerhold.ErrNotFound) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosCompleteNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnQosComplete(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnQosCompleteClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnQosComplete(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnQosDroppedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnQosDropped(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflights(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1}) | ||||
| 	require.NoError(t, err) | ||||
| 	err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20}) | ||||
| 	require.NoError(t, err) | ||||
| 	err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
|  | ||||
| 	var v []storage.Message | ||||
| 	err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey)) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	require.Len(t, v, 1) | ||||
| 	require.Equal(t, "i1", v[0].ID) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTick(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	info := &system.Info{ | ||||
| 		Version:       "2.0.0", | ||||
| 		BytesReceived: 100, | ||||
| 	} | ||||
|  | ||||
| 	h.OnSysInfoTick(info) | ||||
|  | ||||
| 	r := new(storage.SystemInfo) | ||||
| 	err = h.db.Get(storage.SysInfoKey, r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, info.Version, r.Version) | ||||
| 	require.Equal(t, info.BytesReceived, r.BytesReceived) | ||||
| 	require.NotSame(t, info, r) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTickNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnSysInfoTick(new(system.Info)) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTickClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnSysInfoTick(new(system.Info)) | ||||
| } | ||||
|  | ||||
| func TestStoredClients(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with clients | ||||
| 	err = h.db.Upsert("cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredClients() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	require.Equal(t, "cl1", r[0].ID) | ||||
| 	require.Equal(t, "cl2", r[1].ID) | ||||
| 	require.Equal(t, "cl3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredClientsNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredClients() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSubscriptions(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with subscriptions | ||||
| 	err = h.db.Upsert("sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredSubscriptions() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	require.Equal(t, "sub1", r[0].ID) | ||||
| 	require.Equal(t, "sub2", r[1].ID) | ||||
| 	require.Equal(t, "sub3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredSubscriptionsNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredSubscriptions() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredRetainedMessages(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with messages | ||||
| 	err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("m2", &storage.Message{ID: "m2", T: storage.RetainedKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("m3", &storage.Message{ID: "m3", T: storage.RetainedKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredRetainedMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	require.Equal(t, "m1", r[0].ID) | ||||
| 	require.Equal(t, "m2", r[1].ID) | ||||
| 	require.Equal(t, "m3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredRetainedMessagesNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredRetainedMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredInflightMessages(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with messages | ||||
| 	err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredInflightMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	require.Equal(t, "i1", r[0].ID) | ||||
| 	require.Equal(t, "i2", r[1].ID) | ||||
| 	require.Equal(t, "i3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredInflightMessagesNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredInflightMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSysInfo(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with messages | ||||
| 	err = h.db.Upsert(storage.SysInfoKey, &storage.SystemInfo{ | ||||
| 		ID: storage.SysInfoKey, | ||||
| 		Info: system.Info{ | ||||
| 			Version: "2.0.0", | ||||
| 		}, | ||||
| 		T: storage.SysInfoKey, | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredSysInfo() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "2.0.0", r.Info.Version) | ||||
| } | ||||
|  | ||||
| func TestStoredSysInfoNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredSysInfo() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestErrorf(t *testing.T) { | ||||
| 	// coverage: one day check log hook | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.Errorf("test", 1, 2, 3) | ||||
| } | ||||
|  | ||||
| func TestWarningf(t *testing.T) { | ||||
| 	// coverage: one day check log hook | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.Warningf("test", 1, 2, 3) | ||||
| } | ||||
|  | ||||
| func TestInfof(t *testing.T) { | ||||
| 	// coverage: one day check log hook | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.Infof("test", 1, 2, 3) | ||||
| } | ||||
|  | ||||
| func TestDebugf(t *testing.T) { | ||||
| 	// coverage: one day check log hook | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.Debugf("test", 1, 2, 3) | ||||
| } | ||||
							
								
								
									
										486
									
								
								hooks/storage/bolt/bolt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										486
									
								
								hooks/storage/bolt/bolt.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,486 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
| // package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead. | ||||
| package bolt | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	sgob "github.com/asdine/storm/codec/gob" | ||||
| 	"github.com/asdine/storm/v3" | ||||
| 	"go.etcd.io/bbolt" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	// defaultDbFile is the default file path for the boltdb file. | ||||
| 	defaultDbFile = "bolt.db" | ||||
|  | ||||
| 	// defaultTimeout is the default time to hold a connection to the file. | ||||
| 	defaultTimeout = 250 * time.Millisecond | ||||
| ) | ||||
|  | ||||
| // clientKey returns a primary key for a client. | ||||
| func clientKey(cl *mqtt.Client) string { | ||||
| 	return cl.ID | ||||
| } | ||||
|  | ||||
| // subscriptionKey returns a primary key for a subscription. | ||||
| func subscriptionKey(cl *mqtt.Client, filter string) string { | ||||
| 	return storage.SubscriptionKey + "_" + cl.ID + ":" + filter | ||||
| } | ||||
|  | ||||
| // retainedKey returns a primary key for a retained message. | ||||
| func retainedKey(topic string) string { | ||||
| 	return storage.RetainedKey + "_" + topic | ||||
| } | ||||
|  | ||||
| // inflightKey returns a primary key for an inflight message. | ||||
| func inflightKey(cl *mqtt.Client, pk packets.Packet) string { | ||||
| 	return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID() | ||||
| } | ||||
|  | ||||
| // sysInfoKey returns a primary key for system info. | ||||
| func sysInfoKey() string { | ||||
| 	return storage.SysInfoKey | ||||
| } | ||||
|  | ||||
| // Options contains configuration settings for the bolt instance. | ||||
| type Options struct { | ||||
| 	Options *bbolt.Options | ||||
| 	Path    string | ||||
| } | ||||
|  | ||||
| // Hook is a persistent storage hook based using boltdb file store as a backend. | ||||
| type Hook struct { | ||||
| 	mqtt.HookBase | ||||
| 	config *Options  // options for configuring the boltdb instance. | ||||
| 	db     *storm.DB // the boltdb instance. | ||||
| } | ||||
|  | ||||
| // ID returns the id of the hook. | ||||
| func (h *Hook) ID() string { | ||||
| 	return "bolt-db" | ||||
| } | ||||
|  | ||||
| // Provides indicates which hook methods this hook provides. | ||||
| func (h *Hook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{ | ||||
| 		mqtt.OnSessionEstablished, | ||||
| 		mqtt.OnDisconnect, | ||||
| 		mqtt.OnSubscribed, | ||||
| 		mqtt.OnUnsubscribed, | ||||
| 		mqtt.OnRetainMessage, | ||||
| 		mqtt.OnWillSent, | ||||
| 		mqtt.OnQosPublish, | ||||
| 		mqtt.OnQosComplete, | ||||
| 		mqtt.OnQosDropped, | ||||
| 		mqtt.OnSysInfoTick, | ||||
| 		mqtt.OnClientExpired, | ||||
| 		mqtt.OnRetainedExpired, | ||||
| 		mqtt.OnExpireInflights, | ||||
| 		mqtt.StoredClients, | ||||
| 		mqtt.StoredInflightMessages, | ||||
| 		mqtt.StoredRetainedMessages, | ||||
| 		mqtt.StoredSubscriptions, | ||||
| 		mqtt.StoredSysInfo, | ||||
| 	}, []byte{b}) | ||||
| } | ||||
|  | ||||
| // Init initializes and connects to the boltdb instance. | ||||
| func (h *Hook) Init(config any) error { | ||||
| 	if _, ok := config.(*Options); !ok && config != nil { | ||||
| 		return mqtt.ErrInvalidConfigType | ||||
| 	} | ||||
|  | ||||
| 	if config == nil { | ||||
| 		config = new(Options) | ||||
| 	} | ||||
|  | ||||
| 	h.config = config.(*Options) | ||||
| 	if h.config.Options == nil { | ||||
| 		h.config.Options = &bbolt.Options{ | ||||
| 			Timeout: defaultTimeout, | ||||
| 		} | ||||
| 	} | ||||
| 	if h.config.Path == "" { | ||||
| 		h.config.Path = defaultDbFile | ||||
| 	} | ||||
|  | ||||
| 	var err error | ||||
| 	h.db, err = storm.Open(h.config.Path, storm.BoltOptions(0600, h.config.Options), storm.Codec(sgob.Codec)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Stop closes the boltdb instance. | ||||
| func (h *Hook) Stop() error { | ||||
| 	return h.db.Close() | ||||
| } | ||||
|  | ||||
| // OnSessionEstablished adds a client to the store when their session is established. | ||||
| func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.updateClient(cl) | ||||
| } | ||||
|  | ||||
| // OnWillSent is called when a client sends a will message and the will message is removed | ||||
| // from the client record. | ||||
| func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.updateClient(cl) | ||||
| } | ||||
|  | ||||
| // updateClient writes the client data to the store. | ||||
| func (h *Hook) updateClient(cl *mqtt.Client) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := cl.Properties.Props.Copy(false) | ||||
| 	in := &storage.Client{ | ||||
| 		ID:              clientKey(cl), | ||||
| 		T:               storage.ClientKey, | ||||
| 		Remote:          cl.Net.Remote, | ||||
| 		Listener:        cl.Net.Listener, | ||||
| 		Username:        cl.Properties.Username, | ||||
| 		Clean:           cl.Properties.Clean, | ||||
| 		ProtocolVersion: cl.Properties.ProtocolVersion, | ||||
| 		Properties: storage.ClientProperties{ | ||||
| 			SessionExpiryInterval: props.SessionExpiryInterval, | ||||
| 			AuthenticationMethod:  props.AuthenticationMethod, | ||||
| 			AuthenticationData:    props.AuthenticationData, | ||||
| 			RequestProblemInfo:    props.RequestProblemInfo, | ||||
| 			RequestResponseInfo:   props.RequestResponseInfo, | ||||
| 			ReceiveMaximum:        props.ReceiveMaximum, | ||||
| 			TopicAliasMaximum:     props.TopicAliasMaximum, | ||||
| 			User:                  props.User, | ||||
| 			MaximumPacketSize:     props.MaximumPacketSize, | ||||
| 		}, | ||||
| 		Will: storage.ClientWill(cl.Properties.Will), | ||||
| 	} | ||||
| 	err := h.db.Save(in) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnDisconnect removes a client from the store if they were using a clean session. | ||||
| func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !expire { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnSubscribed adds one or more client subscriptions to the store. | ||||
| func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var in *storage.Subscription | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		in = &storage.Subscription{ | ||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:      storage.SubscriptionKey, | ||||
| 			Client: cl.ID, | ||||
| 			Filter: pk.Filters[i].Filter, | ||||
| 			Qos:    reasonCodes[i], | ||||
| 		} | ||||
| 		err := h.db.Save(in) | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err). | ||||
| 				Str("client", cl.ID). | ||||
| 				Interface("data", in). | ||||
| 				Msg("failed to save subscription data") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnUnsubscribed removes one or more client subscriptions from the store. | ||||
| func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		err := h.db.DeleteStruct(&storage.Subscription{ | ||||
| 			ID: subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err). | ||||
| 				Str("id", subscriptionKey(cl, pk.Filters[i].Filter)). | ||||
| 				Msg("failed to delete client") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainMessage adds a retained message for a topic to the store. | ||||
| func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if r == -1 { | ||||
| 		err := h.db.DeleteStruct(&storage.Message{ | ||||
| 			ID: retainedKey(pk.TopicName), | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err). | ||||
| 				Str("id", retainedKey(pk.TopicName)). | ||||
| 				Msg("failed to delete retained publish") | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := pk.Properties.Copy(false) | ||||
| 	in := &storage.Message{ | ||||
| 		ID:          retainedKey(pk.TopicName), | ||||
| 		T:           storage.RetainedKey, | ||||
| 		FixedHeader: pk.FixedHeader, | ||||
| 		TopicName:   pk.TopicName, | ||||
| 		Payload:     pk.Payload, | ||||
| 		Created:     pk.Created, | ||||
| 		Origin:      pk.Origin, | ||||
| 		Properties: storage.MessageProperties{ | ||||
| 			PayloadFormat:          props.PayloadFormat, | ||||
| 			MessageExpiryInterval:  props.MessageExpiryInterval, | ||||
| 			ContentType:            props.ContentType, | ||||
| 			ResponseTopic:          props.ResponseTopic, | ||||
| 			CorrelationData:        props.CorrelationData, | ||||
| 			SubscriptionIdentifier: props.SubscriptionIdentifier, | ||||
| 			TopicAlias:             props.TopicAlias, | ||||
| 			User:                   props.User, | ||||
| 		}, | ||||
| 	} | ||||
| 	err := h.db.Save(in) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err). | ||||
| 			Str("client", cl.ID). | ||||
| 			Interface("data", in). | ||||
| 			Msg("failed to save retained publish data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosPublish adds or updates an inflight message in the store. | ||||
| func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := pk.Properties.Copy(false) | ||||
| 	in := &storage.Message{ | ||||
| 		ID:          inflightKey(cl, pk), | ||||
| 		T:           storage.InflightKey, | ||||
| 		Origin:      pk.Origin, | ||||
| 		FixedHeader: pk.FixedHeader, | ||||
| 		TopicName:   pk.TopicName, | ||||
| 		Payload:     pk.Payload, | ||||
| 		Sent:        sent, | ||||
| 		Created:     pk.Created, | ||||
| 		Properties: storage.MessageProperties{ | ||||
| 			PayloadFormat:          props.PayloadFormat, | ||||
| 			MessageExpiryInterval:  props.MessageExpiryInterval, | ||||
| 			ContentType:            props.ContentType, | ||||
| 			ResponseTopic:          props.ResponseTopic, | ||||
| 			CorrelationData:        props.CorrelationData, | ||||
| 			SubscriptionIdentifier: props.SubscriptionIdentifier, | ||||
| 			TopicAlias:             props.TopicAlias, | ||||
| 			User:                   props.User, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Save(in) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err). | ||||
| 			Str("client", cl.ID). | ||||
| 			Interface("data", in). | ||||
| 			Msg("failed to save qos inflight data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosComplete removes a resolved inflight message from the store. | ||||
| func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.DeleteStruct(&storage.Message{ | ||||
| 		ID: inflightKey(cl, pk), | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err). | ||||
| 			Str("id", inflightKey(cl, pk)). | ||||
| 			Msg("failed to delete inflight data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosDropped removes a dropped inflight message from the store. | ||||
| func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 	} | ||||
|  | ||||
| 	h.OnQosComplete(cl, pk) | ||||
| } | ||||
|  | ||||
| // OnSysInfoTick stores the latest system info in the store. | ||||
| func (h *Hook) OnSysInfoTick(sys *system.Info) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	in := &storage.SystemInfo{ | ||||
| 		ID:   sysInfoKey(), | ||||
| 		T:    storage.SysInfoKey, | ||||
| 		Info: *sys, | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.Save(in) | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err). | ||||
| 			Interface("data", in). | ||||
| 			Msg("failed to save $SYS data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnExpireInflights removes all inflight messages which have passed the | ||||
| // provided expiry time. | ||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var v []storage.Message | ||||
| 	err := h.db.Find("T", storage.InflightKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, m := range v { | ||||
| 		if m.Created < expiry || m.Created == 0 { | ||||
| 			err := h.db.DeleteStruct(&storage.Message{ID: m.ID}) | ||||
| 			if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 				h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data") | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnClientExpired deleted expired clients from the store. | ||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||
| 	err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)}) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // StoredClients returns all stored clients from the store. | ||||
| func (h *Hook) StoredClients() (v []storage.Client, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Find("T", storage.ClientKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredSubscriptions returns all stored subscriptions from the store. | ||||
| func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Find("T", storage.SubscriptionKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredRetainedMessages returns all stored retained messages from the store. | ||||
| func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Find("T", storage.RetainedKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredInflightMessages returns all stored inflight messages from the store. | ||||
| func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Find("T", storage.InflightKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredSysInfo returns the system info from the store. | ||||
| func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.One("ID", storage.SysInfoKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
							
								
								
									
										730
									
								
								hooks/storage/bolt/bolt_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										730
									
								
								hooks/storage/bolt/bolt_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,730 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package bolt | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	"github.com/asdine/storm/v3" | ||||
| 	"github.com/rs/zerolog" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) | ||||
|  | ||||
| 	client = &mqtt.Client{ | ||||
| 		ID: "test", | ||||
| 		Net: mqtt.ClientConnection{ | ||||
| 			Remote:   "test.addr", | ||||
| 			Listener: "listener", | ||||
| 		}, | ||||
| 		Properties: mqtt.ClientProperties{ | ||||
| 			Username: []byte("username"), | ||||
| 			Clean:    false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}} | ||||
| ) | ||||
|  | ||||
| func teardown(t *testing.T, path string, h *Hook) { | ||||
| 	h.Stop() | ||||
| 	err := os.Remove(path) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestClientKey(t *testing.T) { | ||||
| 	k := clientKey(&mqtt.Client{ID: "cl1"}) | ||||
| 	require.Equal(t, "cl1", k) | ||||
| } | ||||
|  | ||||
| func TestSubscriptionKey(t *testing.T) { | ||||
| 	k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c") | ||||
| 	require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k) | ||||
| } | ||||
|  | ||||
| func TestRetainedKey(t *testing.T) { | ||||
| 	k := retainedKey("a/b/c") | ||||
| 	require.Equal(t, storage.RetainedKey+"_a/b/c", k) | ||||
| } | ||||
|  | ||||
| func TestInflightKey(t *testing.T) { | ||||
| 	k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1}) | ||||
| 	require.Equal(t, storage.InflightKey+"_cl1:1", k) | ||||
| } | ||||
|  | ||||
| func TestSysInfoKey(t *testing.T) { | ||||
| 	require.Equal(t, storage.SysInfoKey, sysInfoKey()) | ||||
| } | ||||
|  | ||||
| func TestID(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	require.Equal(t, "bolt-db", h.ID()) | ||||
| } | ||||
|  | ||||
| func TestProvides(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	require.True(t, h.Provides(mqtt.OnSessionEstablished)) | ||||
| 	require.True(t, h.Provides(mqtt.OnDisconnect)) | ||||
| 	require.True(t, h.Provides(mqtt.OnSubscribed)) | ||||
| 	require.True(t, h.Provides(mqtt.OnUnsubscribed)) | ||||
| 	require.True(t, h.Provides(mqtt.OnRetainMessage)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosPublish)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosComplete)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosDropped)) | ||||
| 	require.True(t, h.Provides(mqtt.OnSysInfoTick)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredClients)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredInflightMessages)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredRetainedMessages)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredSubscriptions)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredSysInfo)) | ||||
| 	require.False(t, h.Provides(mqtt.OnACLCheck)) | ||||
| 	require.False(t, h.Provides(mqtt.OnConnectAuthenticate)) | ||||
| } | ||||
|  | ||||
| func TestInitBadConfig(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(map[string]any{}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestInitUseDefaults(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	require.Equal(t, defaultTimeout, h.config.Options.Timeout) | ||||
| 	require.Equal(t, defaultDbFile, h.config.Path) | ||||
| } | ||||
|  | ||||
| func TestInitBadPath(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(&Options{ | ||||
| 		Path: "..", | ||||
| 	}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	err = h.db.One("ID", clientKey(client), r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, client.ID, r.ID) | ||||
| 	require.Equal(t, client.Net.Remote, r.Remote) | ||||
| 	require.Equal(t, client.Net.Listener, r.Listener) | ||||
| 	require.Equal(t, client.Properties.Username, r.Username) | ||||
| 	require.Equal(t, client.Properties.Clean, r.Clean) | ||||
| 	require.NotSame(t, client, r) | ||||
|  | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| 	r2 := new(storage.Client) | ||||
| 	err = h.db.One("ID", clientKey(client), r2) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, client.ID, r.ID) | ||||
|  | ||||
| 	h.OnDisconnect(client, nil, true) | ||||
| 	r3 := new(storage.Client) | ||||
| 	err = h.db.One("ID", clientKey(client), r3) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, storm.ErrNotFound, err) | ||||
| 	require.Empty(t, r3.ID) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnWillSent(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	c1 := client | ||||
| 	c1.Properties.Will.Flag = 1 | ||||
| 	h.OnWillSent(c1, packets.Packet{}) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	err = h.db.One("ID", clientKey(client), r) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.Equal(t, uint32(1), r.Will.Flag) | ||||
| 	require.NotSame(t, client, r) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpired(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	cl := &mqtt.Client{ID: "cl1"} | ||||
| 	clientKey := clientKey(cl) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Client{ID: cl.ID}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	err = h.db.One("ID", clientKey, r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, cl.ID, r.ID) | ||||
|  | ||||
| 	h.OnClientExpired(cl) | ||||
| 	err = h.db.One("ID", clientKey, r) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, storm.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
| 	r := new(storage.Subscription) | ||||
|  | ||||
| 	err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, client.ID, r.Client) | ||||
| 	require.Equal(t, pkf.Filters[0].Filter, r.Filter) | ||||
| 	require.Equal(t, byte(0), r.Qos) | ||||
|  | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| 	err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, storm.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
| } | ||||
|  | ||||
| func TestOnUnsubscribedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| } | ||||
|  | ||||
| func TestOnUnsubscribedClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageThenUnset(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	pk := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Retain: true, | ||||
| 		}, | ||||
| 		Payload:   []byte("hello"), | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	h.OnRetainMessage(client, pk, 1) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	err = h.db.One("ID", retainedKey(pk.TopicName), r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, pk.TopicName, r.TopicName) | ||||
| 	require.Equal(t, pk.Payload, r.Payload) | ||||
|  | ||||
| 	h.OnRetainMessage(client, pk, -1) | ||||
| 	err = h.db.One("ID", retainedKey(pk.TopicName), r) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, storm.ErrNotFound, err) | ||||
|  | ||||
| 	// coverage: delete deleted | ||||
| 	h.OnRetainMessage(client, pk, -1) | ||||
| 	err = h.db.One("ID", retainedKey(pk.TopicName), r) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, storm.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnRetainedExpired(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	m := &storage.Message{ | ||||
| 		ID:        retainedKey("a/b/c"), | ||||
| 		T:         storage.RetainedKey, | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	err = h.db.Save(m) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	err = h.db.One("ID", m.ID, r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, m.TopicName, r.TopicName) | ||||
|  | ||||
| 	h.OnRetainedExpired(m.TopicName) | ||||
| 	err = h.db.One("ID", m.ID, r) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, storm.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnRetainMessage(client, packets.Packet{}, 0) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnRetainMessage(client, packets.Packet{}, 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishThenQOSComplete(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	pk := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Retain: true, | ||||
| 			Qos:    2, | ||||
| 		}, | ||||
| 		Payload:   []byte("hello"), | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	h.OnQosPublish(client, pk, time.Now().Unix(), 0) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	err = h.db.One("ID", inflightKey(client, pk), r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, pk.TopicName, r.TopicName) | ||||
| 	require.Equal(t, pk.Payload, r.Payload) | ||||
|  | ||||
| 	// ensure dates are properly saved to bolt | ||||
| 	require.True(t, r.Sent > 0) | ||||
| 	require.True(t, time.Now().Unix()-1 < r.Sent) | ||||
|  | ||||
| 	// OnQosDropped is a passthrough to OnQosComplete here | ||||
| 	h.OnQosDropped(client, pk) | ||||
| 	err = h.db.One("ID", inflightKey(client, pk), r) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, storm.ErrNotFound, err) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosCompleteNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnQosComplete(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnQosCompleteClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnQosComplete(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnQosDroppedNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnQosDropped(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflights(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1}) | ||||
| 	require.NoError(t, err) | ||||
| 	err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20}) | ||||
| 	require.NoError(t, err) | ||||
| 	err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
|  | ||||
| 	var v []storage.Message | ||||
| 	err = h.db.Find("T", storage.InflightKey, &v) | ||||
| 	if err != nil && !errors.Is(err, storm.ErrNotFound) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	require.Len(t, v, 1) | ||||
| 	require.Equal(t, "i1", v[0].ID) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTick(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	info := &system.Info{ | ||||
| 		Version:       "2.0.0", | ||||
| 		BytesReceived: 100, | ||||
| 	} | ||||
|  | ||||
| 	h.OnSysInfoTick(info) | ||||
|  | ||||
| 	r := new(storage.SystemInfo) | ||||
| 	err = h.db.One("ID", storage.SysInfoKey, r) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, info.Version, r.Version) | ||||
| 	require.Equal(t, info.BytesReceived, r.BytesReceived) | ||||
| 	require.NotSame(t, info, r) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTickNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	h.OnSysInfoTick(new(system.Info)) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTickClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	h.OnSysInfoTick(new(system.Info)) | ||||
| } | ||||
|  | ||||
| func TestStoredClients(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with clients | ||||
| 	err = h.db.Save(&storage.Client{ID: "cl1", T: storage.ClientKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Client{ID: "cl2", T: storage.ClientKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Client{ID: "cl3", T: storage.ClientKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredClients() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	require.Equal(t, "cl1", r[0].ID) | ||||
| 	require.Equal(t, "cl2", r[1].ID) | ||||
| 	require.Equal(t, "cl3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredClientsNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredClients() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredClientsClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	v, err := h.StoredClients() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSubscriptions(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with subscriptions | ||||
| 	err = h.db.Save(&storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredSubscriptions() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	require.Equal(t, "sub1", r[0].ID) | ||||
| 	require.Equal(t, "sub2", r[1].ID) | ||||
| 	require.Equal(t, "sub3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredSubscriptionsNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredSubscriptions() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSubscriptionsClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	v, err := h.StoredSubscriptions() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredRetainedMessages(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with messages | ||||
| 	err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Message{ID: "m2", T: storage.RetainedKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Message{ID: "m3", T: storage.RetainedKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredRetainedMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	require.Equal(t, "m1", r[0].ID) | ||||
| 	require.Equal(t, "m2", r[1].ID) | ||||
| 	require.Equal(t, "m3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredRetainedMessagesNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredRetainedMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredRetainedMessagesClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	v, err := h.StoredRetainedMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredInflightMessages(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with messages | ||||
| 	err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredInflightMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	require.Equal(t, "i1", r[0].ID) | ||||
| 	require.Equal(t, "i2", r[1].ID) | ||||
| 	require.Equal(t, "i3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredInflightMessagesNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredInflightMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredInflightMessagesClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	v, err := h.StoredInflightMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSysInfo(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h.config.Path, h) | ||||
|  | ||||
| 	// populate with sys info | ||||
| 	err = h.db.Save(&storage.SystemInfo{ | ||||
| 		ID: storage.SysInfoKey, | ||||
| 		Info: system.Info{ | ||||
| 			Version: "2.0.0", | ||||
| 		}, | ||||
| 		T: storage.SysInfoKey, | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredSysInfo() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "2.0.0", r.Info.Version) | ||||
| } | ||||
|  | ||||
| func TestStoredSysInfoNoDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	v, err := h.StoredSysInfo() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSysInfoClosedDB(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	teardown(t, h.config.Path, h) | ||||
| 	v, err := h.StoredSysInfo() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
							
								
								
									
										529
									
								
								hooks/storage/redis/redis.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										529
									
								
								hooks/storage/redis/redis.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,529 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package redis | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	redis "github.com/go-redis/redis/v8" | ||||
| ) | ||||
|  | ||||
| // defaultAddr is the default address to the redis service. | ||||
| const defaultAddr = "localhost:6379" | ||||
|  | ||||
| // defaultHPrefix is a prefix to better identify hsets created by mochi mqtt. | ||||
| const defaultHPrefix = "mochi-" | ||||
|  | ||||
| // clientKey returns a primary key for a client. | ||||
| func clientKey(cl *mqtt.Client) string { | ||||
| 	return cl.ID | ||||
| } | ||||
|  | ||||
| // subscriptionKey returns a primary key for a subscription. | ||||
| func subscriptionKey(cl *mqtt.Client, filter string) string { | ||||
| 	return cl.ID + ":" + filter | ||||
| } | ||||
|  | ||||
| // retainedKey returns a primary key for a retained message. | ||||
| func retainedKey(topic string) string { | ||||
| 	return topic | ||||
| } | ||||
|  | ||||
| // inflightKey returns a primary key for an inflight message. | ||||
| func inflightKey(cl *mqtt.Client, pk packets.Packet) string { | ||||
| 	return cl.ID + ":" + pk.FormatID() | ||||
| } | ||||
|  | ||||
| // sysInfoKey returns a primary key for system info. | ||||
| func sysInfoKey() string { | ||||
| 	return storage.SysInfoKey | ||||
| } | ||||
|  | ||||
| // Options contains configuration settings for the bolt instance. | ||||
| type Options struct { | ||||
| 	HPrefix string | ||||
| 	Options *redis.Options | ||||
| } | ||||
|  | ||||
| // Hook is a persistent storage hook based using Redis as a backend. | ||||
| type Hook struct { | ||||
| 	mqtt.HookBase | ||||
| 	config *Options        // options for connecting to the Redis instance. | ||||
| 	db     *redis.Client   // the Redis instance | ||||
| 	ctx    context.Context // a context for the connection | ||||
| } | ||||
|  | ||||
| // ID returns the id of the hook. | ||||
| func (h *Hook) ID() string { | ||||
| 	return "redis-db" | ||||
| } | ||||
|  | ||||
| // Provides indicates which hook methods this hook provides. | ||||
| func (h *Hook) Provides(b byte) bool { | ||||
| 	return bytes.Contains([]byte{ | ||||
| 		mqtt.OnSessionEstablished, | ||||
| 		mqtt.OnDisconnect, | ||||
| 		mqtt.OnSubscribed, | ||||
| 		mqtt.OnUnsubscribed, | ||||
| 		mqtt.OnRetainMessage, | ||||
| 		mqtt.OnQosPublish, | ||||
| 		mqtt.OnQosComplete, | ||||
| 		mqtt.OnQosDropped, | ||||
| 		mqtt.OnWillSent, | ||||
| 		mqtt.OnSysInfoTick, | ||||
| 		mqtt.OnClientExpired, | ||||
| 		mqtt.OnRetainedExpired, | ||||
| 		mqtt.OnExpireInflights, | ||||
| 		mqtt.StoredClients, | ||||
| 		mqtt.StoredInflightMessages, | ||||
| 		mqtt.StoredRetainedMessages, | ||||
| 		mqtt.StoredSubscriptions, | ||||
| 		mqtt.StoredSysInfo, | ||||
| 	}, []byte{b}) | ||||
| } | ||||
|  | ||||
| // hKey returns a hash set key with a unique prefix. | ||||
| func (h *Hook) hKey(s string) string { | ||||
| 	return h.config.HPrefix + s | ||||
| } | ||||
|  | ||||
| // Init initializes and connects to the redis service. | ||||
| func (h *Hook) Init(config any) error { | ||||
| 	if _, ok := config.(*Options); !ok && config != nil { | ||||
| 		return mqtt.ErrInvalidConfigType | ||||
| 	} | ||||
|  | ||||
| 	h.ctx = context.Background() | ||||
|  | ||||
| 	if config == nil { | ||||
| 		config = &Options{ | ||||
| 			Options: &redis.Options{ | ||||
| 				Addr: defaultAddr, | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	h.config = config.(*Options) | ||||
| 	if h.config.HPrefix == "" { | ||||
| 		h.config.HPrefix = defaultHPrefix | ||||
| 	} | ||||
|  | ||||
| 	h.Log.Info(). | ||||
| 		Str("address", h.config.Options.Addr). | ||||
| 		Str("username", h.config.Options.Username). | ||||
| 		Int("password-len", len(h.config.Options.Password)). | ||||
| 		Int("db", h.config.Options.DB). | ||||
| 		Msg("connecting to redis service") | ||||
|  | ||||
| 	h.db = redis.NewClient(h.config.Options) | ||||
| 	_, err := h.db.Ping(context.Background()).Result() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to ping service: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	h.Log.Info().Msg("connected to redis service") | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Close closes the redis connection. | ||||
| func (h *Hook) Stop() error { | ||||
| 	h.Log.Info().Msg("disconnecting from redis service") | ||||
| 	return h.db.Close() | ||||
| } | ||||
|  | ||||
| // OnSessionEstablished adds a client to the store when their session is established. | ||||
| func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.updateClient(cl) | ||||
| } | ||||
|  | ||||
| // OnWillSent is called when a client sends a will message and the will message is removed | ||||
| // from the client record. | ||||
| func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	h.updateClient(cl) | ||||
| } | ||||
|  | ||||
| // updateClient writes the client data to the store. | ||||
| func (h *Hook) updateClient(cl *mqtt.Client) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := cl.Properties.Props.Copy(false) | ||||
| 	in := &storage.Client{ | ||||
| 		ID:              clientKey(cl), | ||||
| 		T:               storage.ClientKey, | ||||
| 		Remote:          cl.Net.Remote, | ||||
| 		Listener:        cl.Net.Listener, | ||||
| 		Username:        cl.Properties.Username, | ||||
| 		Clean:           cl.Properties.Clean, | ||||
| 		ProtocolVersion: cl.Properties.ProtocolVersion, | ||||
| 		Properties: storage.ClientProperties{ | ||||
| 			SessionExpiryInterval: props.SessionExpiryInterval, | ||||
| 			AuthenticationMethod:  props.AuthenticationMethod, | ||||
| 			AuthenticationData:    props.AuthenticationData, | ||||
| 			RequestProblemInfo:    props.RequestProblemInfo, | ||||
| 			RequestResponseInfo:   props.RequestResponseInfo, | ||||
| 			ReceiveMaximum:        props.ReceiveMaximum, | ||||
| 			TopicAliasMaximum:     props.TopicAliasMaximum, | ||||
| 			User:                  props.User, | ||||
| 			MaximumPacketSize:     props.MaximumPacketSize, | ||||
| 		}, | ||||
| 		Will: storage.ClientWill(cl.Properties.Will), | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnDisconnect removes a client from the store if they were using a clean session. | ||||
| func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if !expire { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnSubscribed adds one or more client subscriptions to the store. | ||||
| func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var in *storage.Subscription | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		in = &storage.Subscription{ | ||||
| 			ID:     subscriptionKey(cl, pk.Filters[i].Filter), | ||||
| 			T:      storage.SubscriptionKey, | ||||
| 			Client: cl.ID, | ||||
| 			Filter: pk.Filters[i].Filter, | ||||
| 			Qos:    reasonCodes[i], | ||||
| 		} | ||||
|  | ||||
| 		err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err() | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnUnsubscribed removes one or more client subscriptions from the store. | ||||
| func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < len(pk.Filters); i++ { | ||||
| 		err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err() | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainMessage adds a retained message for a topic to the store. | ||||
| func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if r == -1 { | ||||
| 		err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err() | ||||
| 		if err != nil { | ||||
| 			h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data") | ||||
| 		} | ||||
|  | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := pk.Properties.Copy(false) | ||||
| 	in := &storage.Message{ | ||||
| 		ID:          retainedKey(pk.TopicName), | ||||
| 		T:           storage.RetainedKey, | ||||
| 		FixedHeader: pk.FixedHeader, | ||||
| 		TopicName:   pk.TopicName, | ||||
| 		Payload:     pk.Payload, | ||||
| 		Created:     pk.Created, | ||||
| 		Origin:      pk.Origin, | ||||
| 		Properties: storage.MessageProperties{ | ||||
| 			PayloadFormat:          props.PayloadFormat, | ||||
| 			MessageExpiryInterval:  props.MessageExpiryInterval, | ||||
| 			ContentType:            props.ContentType, | ||||
| 			ResponseTopic:          props.ResponseTopic, | ||||
| 			CorrelationData:        props.CorrelationData, | ||||
| 			SubscriptionIdentifier: props.SubscriptionIdentifier, | ||||
| 			TopicAlias:             props.TopicAlias, | ||||
| 			User:                   props.User, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosPublish adds or updates an inflight message in the store. | ||||
| func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	props := pk.Properties.Copy(false) | ||||
| 	in := &storage.Message{ | ||||
| 		ID:          inflightKey(cl, pk), | ||||
| 		T:           storage.InflightKey, | ||||
| 		Origin:      pk.Origin, | ||||
| 		FixedHeader: pk.FixedHeader, | ||||
| 		TopicName:   pk.TopicName, | ||||
| 		Payload:     pk.Payload, | ||||
| 		Sent:        sent, | ||||
| 		Created:     pk.Created, | ||||
| 		Properties: storage.MessageProperties{ | ||||
| 			PayloadFormat:          props.PayloadFormat, | ||||
| 			MessageExpiryInterval:  props.MessageExpiryInterval, | ||||
| 			ContentType:            props.ContentType, | ||||
| 			ResponseTopic:          props.ResponseTopic, | ||||
| 			CorrelationData:        props.CorrelationData, | ||||
| 			SubscriptionIdentifier: props.SubscriptionIdentifier, | ||||
| 			TopicAlias:             props.TopicAlias, | ||||
| 			User:                   props.User, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosComplete removes a resolved inflight message from the store. | ||||
| func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnQosDropped removes a dropped inflight message from the store. | ||||
| func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 	} | ||||
|  | ||||
| 	h.OnQosComplete(cl, pk) | ||||
| } | ||||
|  | ||||
| // OnSysInfoTick stores the latest system info in the store. | ||||
| func (h *Hook) OnSysInfoTick(sys *system.Info) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	in := &storage.SystemInfo{ | ||||
| 		ID:   sysInfoKey(), | ||||
| 		T:    storage.SysInfoKey, | ||||
| 		Info: *sys, | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnExpireInflights removes all inflight messages which have passed the | ||||
| // provided expiry time. | ||||
| func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() | ||||
| 	if err != nil && !errors.Is(err, redis.Nil) { | ||||
| 		h.Log.Error().Err(err).Msg("failed to HGetAll inflight data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, row := range rows { | ||||
| 		var d storage.Message | ||||
| 		if err = d.UnmarshalBinary([]byte(row)); err != nil { | ||||
| 			h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data") | ||||
| 		} | ||||
|  | ||||
| 		if d.Created < expiry || d.Created == 0 { | ||||
| 			err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err() | ||||
| 			if err != nil { | ||||
| 				h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnRetainedExpired deletes expired retained messages from the store. | ||||
| func (h *Hook) OnRetainedExpired(filter string) { | ||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnClientExpired deleted expired clients from the store. | ||||
| func (h *Hook) OnClientExpired(cl *mqtt.Client) { | ||||
| 	err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err() | ||||
| 	if err != nil { | ||||
| 		h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // StoredClients returns all stored clients from the store. | ||||
| func (h *Hook) StoredClients() (v []storage.Client, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result() | ||||
| 	if err != nil && !errors.Is(err, redis.Nil) { | ||||
| 		h.Log.Error().Err(err).Msg("failed to HGetAll client data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, row := range rows { | ||||
| 		var d storage.Client | ||||
| 		if err = d.UnmarshalBinary([]byte(row)); err != nil { | ||||
| 			h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data") | ||||
| 		} | ||||
|  | ||||
| 		v = append(v, d) | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredSubscriptions returns all stored subscriptions from the store. | ||||
| func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result() | ||||
| 	if err != nil && !errors.Is(err, redis.Nil) { | ||||
| 		h.Log.Error().Err(err).Msg("failed to HGetAll subscription data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, row := range rows { | ||||
| 		var d storage.Subscription | ||||
| 		if err = d.UnmarshalBinary([]byte(row)); err != nil { | ||||
| 			h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data") | ||||
| 		} | ||||
|  | ||||
| 		v = append(v, d) | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredRetainedMessages returns all stored retained messages from the store. | ||||
| func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result() | ||||
| 	if err != nil && !errors.Is(err, redis.Nil) { | ||||
| 		h.Log.Error().Err(err).Msg("failed to HGetAll retained message data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, row := range rows { | ||||
| 		var d storage.Message | ||||
| 		if err = d.UnmarshalBinary([]byte(row)); err != nil { | ||||
| 			h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data") | ||||
| 		} | ||||
|  | ||||
| 		v = append(v, d) | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredInflightMessages returns all stored inflight messages from the store. | ||||
| func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() | ||||
| 	if err != nil && !errors.Is(err, redis.Nil) { | ||||
| 		h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	for _, row := range rows { | ||||
| 		var d storage.Message | ||||
| 		if err = d.UnmarshalBinary([]byte(row)); err != nil { | ||||
| 			h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data") | ||||
| 		} | ||||
|  | ||||
| 		v = append(v, d) | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // StoredSysInfo returns the system info from the store. | ||||
| func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| 	if h.db == nil { | ||||
| 		h.Log.Error().Err(storage.ErrDBFileNotOpen) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result() | ||||
| 	if err != nil && !errors.Is(err, redis.Nil) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if err = v.UnmarshalBinary([]byte(row)); err != nil { | ||||
| 		h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data") | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
							
								
								
									
										811
									
								
								hooks/storage/redis/redis_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										811
									
								
								hooks/storage/redis/redis_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,811 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package redis | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| 	"sort" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2" | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	miniredis "github.com/alicebob/miniredis/v2" | ||||
| 	redis "github.com/go-redis/redis/v8" | ||||
| 	"github.com/rs/zerolog" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) | ||||
|  | ||||
| 	client = &mqtt.Client{ | ||||
| 		ID: "test", | ||||
| 		Net: mqtt.ClientConnection{ | ||||
| 			Remote:   "test.addr", | ||||
| 			Listener: "listener", | ||||
| 		}, | ||||
| 		Properties: mqtt.ClientProperties{ | ||||
| 			Username: []byte("username"), | ||||
| 			Clean:    false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}} | ||||
| ) | ||||
|  | ||||
| func newHook(t *testing.T, addr string) *Hook { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(&Options{ | ||||
| 		Options: &redis.Options{ | ||||
| 			Addr: addr, | ||||
| 		}, | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	return h | ||||
| } | ||||
|  | ||||
| func teardown(t *testing.T, h *Hook) { | ||||
| 	if h.db != nil { | ||||
| 		err := h.db.FlushAll(h.ctx).Err() | ||||
| 		require.NoError(t, err) | ||||
| 		h.Stop() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestClientKey(t *testing.T) { | ||||
| 	k := clientKey(&mqtt.Client{ID: "cl1"}) | ||||
| 	require.Equal(t, "cl1", k) | ||||
| } | ||||
|  | ||||
| func TestSubscriptionKey(t *testing.T) { | ||||
| 	k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c") | ||||
| 	require.Equal(t, "cl1:a/b/c", k) | ||||
| } | ||||
|  | ||||
| func TestRetainedKey(t *testing.T) { | ||||
| 	k := retainedKey("a/b/c") | ||||
| 	require.Equal(t, "a/b/c", k) | ||||
| } | ||||
|  | ||||
| func TestInflightKey(t *testing.T) { | ||||
| 	k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1}) | ||||
| 	require.Equal(t, "cl1:1", k) | ||||
| } | ||||
|  | ||||
| func TestSysInfoKey(t *testing.T) { | ||||
| 	require.Equal(t, storage.SysInfoKey, sysInfoKey()) | ||||
| } | ||||
|  | ||||
| func TestID(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	require.Equal(t, "redis-db", h.ID()) | ||||
| } | ||||
|  | ||||
| func TestProvides(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	require.True(t, h.Provides(mqtt.OnSessionEstablished)) | ||||
| 	require.True(t, h.Provides(mqtt.OnDisconnect)) | ||||
| 	require.True(t, h.Provides(mqtt.OnSubscribed)) | ||||
| 	require.True(t, h.Provides(mqtt.OnUnsubscribed)) | ||||
| 	require.True(t, h.Provides(mqtt.OnRetainMessage)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosPublish)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosComplete)) | ||||
| 	require.True(t, h.Provides(mqtt.OnQosDropped)) | ||||
| 	require.True(t, h.Provides(mqtt.OnSysInfoTick)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredClients)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredInflightMessages)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredRetainedMessages)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredSubscriptions)) | ||||
| 	require.True(t, h.Provides(mqtt.StoredSysInfo)) | ||||
| 	require.False(t, h.Provides(mqtt.OnACLCheck)) | ||||
| 	require.False(t, h.Provides(mqtt.OnConnectAuthenticate)) | ||||
| } | ||||
|  | ||||
| func TestHKey(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	require.Equal(t, defaultHPrefix+"test", h.hKey("test")) | ||||
| } | ||||
|  | ||||
| func TestInitUseDefaults(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	s.StartAddr(defaultAddr) | ||||
| 	defer s.Close() | ||||
|  | ||||
| 	h := newHook(t, defaultAddr) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	require.Equal(t, defaultHPrefix, h.config.HPrefix) | ||||
| 	require.Equal(t, defaultAddr, h.config.Options.Addr) | ||||
| } | ||||
|  | ||||
| func TestInitBadConfig(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
|  | ||||
| 	err := h.Init(map[string]any{}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestInitBadAddr(t *testing.T) { | ||||
| 	h := new(Hook) | ||||
| 	h.SetOpts(&logger, nil) | ||||
| 	err := h.Init(&Options{ | ||||
| 		Options: &redis.Options{ | ||||
| 			Addr: "abc:123", | ||||
| 		}, | ||||
| 	}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.Equal(t, client.ID, r.ID) | ||||
| 	require.Equal(t, client.Net.Remote, r.Remote) | ||||
| 	require.Equal(t, client.Net.Listener, r.Listener) | ||||
| 	require.Equal(t, client.Properties.Username, r.Username) | ||||
| 	require.Equal(t, client.Properties.Clean, r.Clean) | ||||
| 	require.NotSame(t, client, r) | ||||
|  | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| 	r2 := new(storage.Client) | ||||
| 	row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r2.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, client.ID, r.ID) | ||||
|  | ||||
| 	h.OnDisconnect(client, nil, true) | ||||
| 	r3 := new(storage.Client) | ||||
| 	_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result() | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, redis.Nil) | ||||
| 	require.Empty(t, r3.ID) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
|  | ||||
| 	h.db = nil | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnSessionEstablishedClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnSessionEstablished(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnWillSent(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	c1 := client | ||||
| 	c1.Properties.Will.Flag = 1 | ||||
| 	h.OnWillSent(c1, packets.Packet{}) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.Equal(t, uint32(1), r.Will.Flag) | ||||
| 	require.NotSame(t, client, r) | ||||
| } | ||||
|  | ||||
| func TestOnClientExpired(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	cl := &mqtt.Client{ID: "cl1"} | ||||
| 	clientKey := clientKey(cl) | ||||
|  | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r := new(storage.Client) | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, clientKey, r.ID) | ||||
|  | ||||
| 	h.OnClientExpired(cl) | ||||
| 	_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result() | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, redis.Nil, err) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnDisconnectClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnDisconnect(client, nil, false) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedThenOnUnsubscribed(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
|  | ||||
| 	r := new(storage.Subscription) | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, client.ID, r.Client) | ||||
| 	require.Equal(t, pkf.Filters[0].Filter, r.Filter) | ||||
| 	require.Equal(t, byte(0), r.Qos) | ||||
|  | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| 	_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result() | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, redis.Nil) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
| } | ||||
|  | ||||
| func TestOnSubscribedClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnSubscribed(client, pkf, []byte{0}) | ||||
| } | ||||
|  | ||||
| func TestOnUnsubscribedNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| } | ||||
|  | ||||
| func TestOnUnsubscribedClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnUnsubscribed(client, pkf) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageThenUnset(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	pk := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Retain: true, | ||||
| 		}, | ||||
| 		Payload:   []byte("hello"), | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	h.OnRetainMessage(client, pk, 1) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, pk.TopicName, r.TopicName) | ||||
| 	require.Equal(t, pk.Payload, r.Payload) | ||||
|  | ||||
| 	h.OnRetainMessage(client, pk, -1) | ||||
| 	_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result() | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, redis.Nil) | ||||
|  | ||||
| 	// coverage: delete deleted | ||||
| 	h.OnRetainMessage(client, pk, -1) | ||||
| 	_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result() | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, redis.Nil) | ||||
| } | ||||
|  | ||||
| func TestOnRetainedExpired(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	m := &storage.Message{ | ||||
| 		ID:        retainedKey("a/b/c"), | ||||
| 		T:         storage.RetainedKey, | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, m.TopicName, r.TopicName) | ||||
|  | ||||
| 	h.OnRetainedExpired(m.TopicName) | ||||
|  | ||||
| 	_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result() | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, redis.Nil) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnRetainMessage(client, packets.Packet{}, 0) | ||||
| } | ||||
|  | ||||
| func TestOnRetainMessageClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnRetainMessage(client, packets.Packet{}, 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishThenQOSComplete(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	pk := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Retain: true, | ||||
| 			Qos:    2, | ||||
| 		}, | ||||
| 		Payload:   []byte("hello"), | ||||
| 		TopicName: "a/b/c", | ||||
| 	} | ||||
|  | ||||
| 	h.OnQosPublish(client, pk, time.Now().Unix(), 0) | ||||
|  | ||||
| 	r := new(storage.Message) | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, pk.TopicName, r.TopicName) | ||||
| 	require.Equal(t, pk.Payload, r.Payload) | ||||
|  | ||||
| 	// ensure dates are properly saved to bolt | ||||
| 	require.True(t, r.Sent > 0) | ||||
| 	require.True(t, time.Now().Unix()-1 < r.Sent) | ||||
|  | ||||
| 	// OnQosDropped is a passthrough to OnQosComplete here | ||||
| 	h.OnQosDropped(client, pk) | ||||
| 	_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result() | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, redis.Nil) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosPublishClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0) | ||||
| } | ||||
|  | ||||
| func TestOnQosCompleteNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnQosComplete(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnQosCompleteClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnQosComplete(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnQosDroppedNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnQosDropped(client, packets.Packet{}) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflights(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	n := time.Now().Unix() | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", | ||||
| 		&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1}, | ||||
| 	).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", | ||||
| 		&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20}, | ||||
| 	).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", | ||||
| 		&storage.Message{ID: "i3", T: storage.InflightKey}, | ||||
| 	).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
|  | ||||
| 	var r []storage.Message | ||||
| 	rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, rows, 1) | ||||
| 	for _, row := range rows { | ||||
| 		var d storage.Message | ||||
| 		err = d.UnmarshalBinary([]byte(row)) | ||||
| 		require.NoError(t, err) | ||||
| 		r = append(r, d) | ||||
| 	} | ||||
| 	require.Len(t, r, 1) | ||||
| 	require.Equal(t, "i1", r[0].ID) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnExpireInflightsNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnExpireInflights(client, time.Now().Unix()-10) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTick(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	info := &system.Info{ | ||||
| 		Version:       "2.0.0", | ||||
| 		BytesReceived: 100, | ||||
| 	} | ||||
|  | ||||
| 	h.OnSysInfoTick(info) | ||||
|  | ||||
| 	r := new(storage.SystemInfo) | ||||
| 	row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result() | ||||
| 	require.NoError(t, err) | ||||
| 	err = r.UnmarshalBinary([]byte(row)) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, info.Version, r.Version) | ||||
| 	require.Equal(t, info.BytesReceived, r.BytesReceived) | ||||
| 	require.NotSame(t, info, r) | ||||
| } | ||||
|  | ||||
| func TestOnSysInfoTickClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
| 	h.OnSysInfoTick(new(system.Info)) | ||||
| } | ||||
| func TestOnSysInfoTickNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	h.OnSysInfoTick(new(system.Info)) | ||||
| } | ||||
|  | ||||
| func TestStoredClients(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	// populate with clients | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredClients() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
|  | ||||
| 	sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID }) | ||||
| 	require.Equal(t, "cl1", r[0].ID) | ||||
| 	require.Equal(t, "cl2", r[1].ID) | ||||
| 	require.Equal(t, "cl3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredClientsNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	v, err := h.StoredClients() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredClientsClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
|  | ||||
| 	v, err := h.StoredClients() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSubscriptions(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	// populate with subscriptions | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredSubscriptions() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID }) | ||||
| 	require.Equal(t, "sub1", r[0].ID) | ||||
| 	require.Equal(t, "sub2", r[1].ID) | ||||
| 	require.Equal(t, "sub3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredSubscriptionsNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	v, err := h.StoredSubscriptions() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSubscriptionsClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
|  | ||||
| 	v, err := h.StoredSubscriptions() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredRetainedMessages(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	// populate with messages | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredRetainedMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID }) | ||||
| 	require.Equal(t, "m1", r[0].ID) | ||||
| 	require.Equal(t, "m2", r[1].ID) | ||||
| 	require.Equal(t, "m3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredRetainedMessagesNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	v, err := h.StoredRetainedMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredRetainedMessagesClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
|  | ||||
| 	v, err := h.StoredRetainedMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredInflightMessages(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	// populate with messages | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredInflightMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, r, 3) | ||||
| 	sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID }) | ||||
| 	require.Equal(t, "i1", r[0].ID) | ||||
| 	require.Equal(t, "i2", r[1].ID) | ||||
| 	require.Equal(t, "i3", r[2].ID) | ||||
| } | ||||
|  | ||||
| func TestStoredInflightMessagesNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	v, err := h.StoredInflightMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredInflightMessagesClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
|  | ||||
| 	v, err := h.StoredInflightMessages() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSysInfo(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	defer teardown(t, h) | ||||
|  | ||||
| 	// populate with sys info | ||||
| 	err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey, | ||||
| 		&storage.SystemInfo{ | ||||
| 			ID: storage.SysInfoKey, | ||||
| 			Info: system.Info{ | ||||
| 				Version: "2.0.0", | ||||
| 			}, | ||||
| 			T: storage.SysInfoKey, | ||||
| 		}).Err() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	r, err := h.StoredSysInfo() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "2.0.0", r.Info.Version) | ||||
| } | ||||
|  | ||||
| func TestStoredSysInfoNoDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	h.db = nil | ||||
| 	v, err := h.StoredSysInfo() | ||||
| 	require.Empty(t, v) | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestStoredSysInfoClosedDB(t *testing.T) { | ||||
| 	s := miniredis.RunT(t) | ||||
| 	defer s.Close() | ||||
| 	h := newHook(t, s.Addr()) | ||||
| 	teardown(t, h) | ||||
|  | ||||
| 	v, err := h.StoredSysInfo() | ||||
| 	require.Empty(t, v) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
							
								
								
									
										164
									
								
								hooks/storage/storage.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										164
									
								
								hooks/storage/storage.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,164 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package storage | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	SubscriptionKey = "SUB" // unique key to denote Subscriptions in a store | ||||
| 	SysInfoKey      = "SYS" // unique key to denote server system information in a store | ||||
| 	RetainedKey     = "RET" // unique key to denote retained messages in a store | ||||
| 	InflightKey     = "IFM" // unique key to denote inflight messages in a store | ||||
| 	ClientKey       = "CL"  // unique key to denote clients in a store | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading. | ||||
| 	ErrDBFileNotOpen = errors.New("db file not open") | ||||
| ) | ||||
|  | ||||
| // Client is a storable representation of an mqtt client. | ||||
| type Client struct { | ||||
| 	Will            ClientWill       `json:"will"`            // will topic and payload data if applicable | ||||
| 	Properties      ClientProperties `json:"properties"`      // the connect properties for the client | ||||
| 	Username        []byte           `json:"username"`        // the username of the client | ||||
| 	ID              string           `json:"id" storm:"id"`   // the client id / storage key | ||||
| 	T               string           `json:"t"`               // the data type (client) | ||||
| 	Remote          string           `json:"remote"`          // the remote address of the client | ||||
| 	Listener        string           `json:"listener"`        // the listener the client connected on | ||||
| 	ProtocolVersion byte             `json:"protocolVersion"` // mqtt protocol version of the client | ||||
| 	Clean           bool             `json:"clean"`           // if the client requested a clean start/session | ||||
| } | ||||
|  | ||||
| // ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection. | ||||
| type ClientProperties struct { | ||||
| 	AuthenticationData        []byte                 `json:"authenticationData"` | ||||
| 	User                      []packets.UserProperty `json:"user"` | ||||
| 	AuthenticationMethod      string                 `json:"authenticationMethod"` | ||||
| 	SessionExpiryInterval     uint32                 `json:"sessionExpiryInterval"` | ||||
| 	MaximumPacketSize         uint32                 `json:"maximumPacketSize"` | ||||
| 	ReceiveMaximum            uint16                 `json:"receiveMaximum"` | ||||
| 	TopicAliasMaximum         uint16                 `json:"topicAliasMaximum"` | ||||
| 	SessionExpiryIntervalFlag bool                   `json:"sessionExpiryIntervalFlag"` | ||||
| 	RequestProblemInfo        byte                   `json:"requestProblemInfo"` | ||||
| 	RequestProblemInfoFlag    bool                   `json:"requestProblemInfoFlag"` | ||||
| 	RequestResponseInfo       byte                   `json:"requestResponseInfo"` | ||||
| } | ||||
|  | ||||
| // ClientWill contains a will message for a client, and limited mqtt v5 properties. | ||||
| type ClientWill struct { | ||||
| 	Payload           []byte                 `json:"payload"` | ||||
| 	User              []packets.UserProperty `json:"user"` | ||||
| 	TopicName         string                 `json:"topicName"` | ||||
| 	Flag              uint32                 `json:"flag"` | ||||
| 	WillDelayInterval uint32                 `json:"willDelayInterval"` | ||||
| 	Qos               byte                   `json:"qos"` | ||||
| 	Retain            bool                   `json:"retain"` | ||||
| } | ||||
|  | ||||
| // MarshalBinary encodes the values into a json string. | ||||
| func (d Client) MarshalBinary() (data []byte, err error) { | ||||
| 	return json.Marshal(d) | ||||
| } | ||||
|  | ||||
| // UnmarshalBinary decodes a json string into a struct. | ||||
| func (d *Client) UnmarshalBinary(data []byte) error { | ||||
| 	if len(data) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return json.Unmarshal(data, d) | ||||
| } | ||||
|  | ||||
| // Message is a storable representation of an MQTT message (specifically publish). | ||||
| type Message struct { | ||||
| 	Properties  MessageProperties   `json:"properties"`    // - | ||||
| 	Payload     []byte              `json:"payload"`       // the message payload (if retained) | ||||
| 	T           string              `json:"t"`             // the data type | ||||
| 	ID          string              `json:"id" storm:"id"` // the storage key | ||||
| 	Origin      string              `json:"origin"`        // the id of the client who sent the message | ||||
| 	TopicName   string              `json:"topic_name"`    // the topic the message was sent to (if retained) | ||||
| 	FixedHeader packets.FixedHeader `json:"fixedheader"`   // the header properties of the message | ||||
| 	Created     int64               `json:"created"`       // the time the message was created in unixtime | ||||
| 	Sent        int64               `json:"sent"`          // the last time the message was sent (for retries) in unixtime (if inflight) | ||||
| 	PacketID    uint16              `json:"packet_id"`     // the unique id of the packet (if inflight) | ||||
| } | ||||
|  | ||||
| // MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages. | ||||
| type MessageProperties struct { | ||||
| 	CorrelationData        []byte                 `json:"correlationData"` | ||||
| 	SubscriptionIdentifier []int                  `json:"subscriptionIdentifier"` | ||||
| 	User                   []packets.UserProperty `json:"user"` | ||||
| 	ContentType            string                 `json:"contentType"` | ||||
| 	ResponseTopic          string                 `json:"responseTopic"` | ||||
| 	MessageExpiryInterval  uint32                 `json:"messageExpiry"` | ||||
| 	TopicAlias             uint16                 `json:"topicAlias"` | ||||
| 	PayloadFormat          byte                   `json:"payloadFormat"` | ||||
| 	PayloadFormatFlag      bool                   `json:"payloadFormatFlag"` | ||||
| } | ||||
|  | ||||
| // MarshalBinary encodes the values into a json string. | ||||
| func (d Message) MarshalBinary() (data []byte, err error) { | ||||
| 	return json.Marshal(d) | ||||
| } | ||||
|  | ||||
| // UnmarshalBinary decodes a json string into a struct. | ||||
| func (d *Message) UnmarshalBinary(data []byte) error { | ||||
| 	if len(data) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return json.Unmarshal(data, d) | ||||
| } | ||||
|  | ||||
| // Subscription is a storable representation of an mqtt subscription. | ||||
| type Subscription struct { | ||||
| 	T                 string `json:"t"` | ||||
| 	ID                string `json:"id" storm:"id"` | ||||
| 	Client            string `json:"client"` | ||||
| 	Filter            string `json:"filter"` | ||||
| 	Identifier        int    `json:"identifier"` | ||||
| 	RetainHandling    byte   `json:"retain_handling"` | ||||
| 	Qos               byte   `json:"qos"` | ||||
| 	RetainAsPublished bool   `json:"retain_as_pub"` | ||||
| 	NoLocal           bool   `json:"no_local"` | ||||
| } | ||||
|  | ||||
| // MarshalBinary encodes the values into a json string. | ||||
| func (d Subscription) MarshalBinary() (data []byte, err error) { | ||||
| 	return json.Marshal(d) | ||||
| } | ||||
|  | ||||
| // UnmarshalBinary decodes a json string into a struct. | ||||
| func (d *Subscription) UnmarshalBinary(data []byte) error { | ||||
| 	if len(data) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return json.Unmarshal(data, d) | ||||
| } | ||||
|  | ||||
| // SystemInfo is a storable representation of the system information values. | ||||
| type SystemInfo struct { | ||||
| 	system.Info        // embed the system info struct | ||||
| 	T           string `json:"t"`             // the data type | ||||
| 	ID          string `json:"id" storm:"id"` // the storage key | ||||
| } | ||||
|  | ||||
| // MarshalBinary encodes the values into a json string. | ||||
| func (d SystemInfo) MarshalBinary() (data []byte, err error) { | ||||
| 	return json.Marshal(d) | ||||
| } | ||||
|  | ||||
| // UnmarshalBinary decodes a json string into a struct. | ||||
| func (d *SystemInfo) UnmarshalBinary(data []byte) error { | ||||
| 	if len(data) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return json.Unmarshal(data, d) | ||||
| } | ||||
							
								
								
									
										195
									
								
								hooks/storage/storage_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										195
									
								
								hooks/storage/storage_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,195 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package storage | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	clientStruct = Client{ | ||||
| 		ID:       "test", | ||||
| 		T:        "client", | ||||
| 		Remote:   "remote", | ||||
| 		Listener: "listener", | ||||
| 		Username: []byte("mochi"), | ||||
| 		Clean:    true, | ||||
| 		Properties: ClientProperties{ | ||||
| 			SessionExpiryInterval:     2, | ||||
| 			SessionExpiryIntervalFlag: true, | ||||
| 			AuthenticationMethod:      "a", | ||||
| 			AuthenticationData:        []byte("test"), | ||||
| 			RequestProblemInfo:        1, | ||||
| 			RequestProblemInfoFlag:    true, | ||||
| 			RequestResponseInfo:       1, | ||||
| 			ReceiveMaximum:            128, | ||||
| 			TopicAliasMaximum:         256, | ||||
| 			User: []packets.UserProperty{ | ||||
| 				{Key: "k", Val: "v"}, | ||||
| 			}, | ||||
| 			MaximumPacketSize: 120, | ||||
| 		}, | ||||
| 		Will: ClientWill{ | ||||
| 			Qos:               1, | ||||
| 			Payload:           []byte("abc"), | ||||
| 			TopicName:         "a/b/c", | ||||
| 			Flag:              1, | ||||
| 			Retain:            true, | ||||
| 			WillDelayInterval: 2, | ||||
| 			User: []packets.UserProperty{ | ||||
| 				{Key: "k2", Val: "v2"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`) | ||||
|  | ||||
| 	messageStruct = Message{ | ||||
| 		T:       "message", | ||||
| 		Payload: []byte("payload"), | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Remaining: 2, | ||||
| 			Type:      3, | ||||
| 			Qos:       1, | ||||
| 			Dup:       true, | ||||
| 			Retain:    true, | ||||
| 		}, | ||||
| 		ID:        "id", | ||||
| 		Origin:    "mochi", | ||||
| 		TopicName: "topic", | ||||
| 		Properties: MessageProperties{ | ||||
| 			PayloadFormat:          1, | ||||
| 			PayloadFormatFlag:      true, | ||||
| 			MessageExpiryInterval:  20, | ||||
| 			ContentType:            "type", | ||||
| 			ResponseTopic:          "a/b/r", | ||||
| 			CorrelationData:        []byte("r"), | ||||
| 			SubscriptionIdentifier: []int{1}, | ||||
| 			TopicAlias:             2, | ||||
| 			User: []packets.UserProperty{ | ||||
| 				{Key: "k2", Val: "v2"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		Created:  time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(), | ||||
| 		Sent:     time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(), | ||||
| 		PacketID: 100, | ||||
| 	} | ||||
| 	messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`) | ||||
|  | ||||
| 	subscriptionStruct = Subscription{ | ||||
| 		T:      "subscription", | ||||
| 		ID:     "id", | ||||
| 		Client: "mochi", | ||||
| 		Filter: "a/b/c", | ||||
| 		Qos:    1, | ||||
| 	} | ||||
| 	subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","identifier":0,"retain_handling":0,"qos":1,"retain_as_pub":false,"no_local":false}`) | ||||
|  | ||||
| 	sysInfoStruct = SystemInfo{ | ||||
| 		T:  "info", | ||||
| 		ID: "id", | ||||
| 		Info: system.Info{ | ||||
| 			Version:          "2.0.0", | ||||
| 			Started:          1, | ||||
| 			Uptime:           2, | ||||
| 			BytesReceived:    3, | ||||
| 			BytesSent:        4, | ||||
| 			ClientsConnected: 5, | ||||
| 			ClientsMaximum:   7, | ||||
| 			MessagesReceived: 10, | ||||
| 			MessagesSent:     11, | ||||
| 			PacketsReceived:  12, | ||||
| 			PacketsSent:      13, | ||||
| 			Retained:         15, | ||||
| 			Inflight:         16, | ||||
| 			InflightDropped:  17, | ||||
| 		}, | ||||
| 	} | ||||
| 	sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`) | ||||
| ) | ||||
|  | ||||
| func TestClientMarshalBinary(t *testing.T) { | ||||
| 	data, err := clientStruct.MarshalBinary() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, clientJSON, data) | ||||
| } | ||||
|  | ||||
| func TestClientUnmarshalBinary(t *testing.T) { | ||||
| 	d := clientStruct | ||||
| 	err := d.UnmarshalBinary(clientJSON) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, clientStruct, d) | ||||
| } | ||||
|  | ||||
| func TestClientUnmarshalBinaryEmpty(t *testing.T) { | ||||
| 	d := Client{} | ||||
| 	err := d.UnmarshalBinary([]byte{}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, Client{}, d) | ||||
| } | ||||
|  | ||||
| func TestMessageMarshalBinary(t *testing.T) { | ||||
| 	data, err := messageStruct.MarshalBinary() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, messageJSON, data) | ||||
| } | ||||
|  | ||||
| func TestMessageUnmarshalBinary(t *testing.T) { | ||||
| 	d := messageStruct | ||||
| 	err := d.UnmarshalBinary(messageJSON) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, messageStruct, d) | ||||
| } | ||||
|  | ||||
| func TestMessageUnmarshalBinaryEmpty(t *testing.T) { | ||||
| 	d := Message{} | ||||
| 	err := d.UnmarshalBinary([]byte{}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, Message{}, d) | ||||
| } | ||||
|  | ||||
| func TestSubscriptionMarshalBinary(t *testing.T) { | ||||
| 	data, err := subscriptionStruct.MarshalBinary() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, subscriptionJSON, data) | ||||
| } | ||||
|  | ||||
| func TestSubscriptionUnmarshalBinary(t *testing.T) { | ||||
| 	d := subscriptionStruct | ||||
| 	err := d.UnmarshalBinary(subscriptionJSON) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, subscriptionStruct, d) | ||||
| } | ||||
|  | ||||
| func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) { | ||||
| 	d := Subscription{} | ||||
| 	err := d.UnmarshalBinary([]byte{}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, Subscription{}, d) | ||||
| } | ||||
|  | ||||
| func TestSysInfoMarshalBinary(t *testing.T) { | ||||
| 	data, err := sysInfoStruct.MarshalBinary() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, sysInfoJSON, data) | ||||
| } | ||||
|  | ||||
| func TestSysInfoUnmarshalBinary(t *testing.T) { | ||||
| 	d := sysInfoStruct | ||||
| 	err := d.UnmarshalBinary(sysInfoJSON) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, sysInfoStruct, d) | ||||
| } | ||||
|  | ||||
| func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) { | ||||
| 	d := SystemInfo{} | ||||
| 	err := d.UnmarshalBinary([]byte{}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, SystemInfo{}, d) | ||||
| } | ||||
							
								
								
									
										622
									
								
								hooks_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										622
									
								
								hooks_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,622 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"strconv" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/hooks/storage" | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| type modifiedHookBase struct { | ||||
| 	HookBase | ||||
| 	err    error | ||||
| 	fail   bool | ||||
| 	failAt int | ||||
| } | ||||
|  | ||||
| var errTestHook = errors.New("error") | ||||
|  | ||||
| func (h *modifiedHookBase) Init(config any) error { | ||||
| 	if config != nil { | ||||
| 		return errTestHook | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) Provides(b byte) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) Stop() error { | ||||
| 	if h.fail { | ||||
| 		return errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) { | ||||
| 	if h.fail { | ||||
| 		if h.err != nil { | ||||
| 			return pk, h.err | ||||
| 		} | ||||
|  | ||||
| 		return pk, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return pk, nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) { | ||||
| 	if h.fail { | ||||
| 		if h.err != nil { | ||||
| 			return pk, h.err | ||||
| 		} | ||||
|  | ||||
| 		return pk, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return pk, nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) { | ||||
| 	if h.fail { | ||||
| 		if h.err != nil { | ||||
| 			return pk, h.err | ||||
| 		} | ||||
|  | ||||
| 		return pk, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return pk, nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) { | ||||
| 	if h.fail { | ||||
| 		return will, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return will, nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) { | ||||
| 	if h.fail || h.failAt == 1 { | ||||
| 		return v, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return []storage.Client{ | ||||
| 		{ID: "cl1"}, | ||||
| 		{ID: "cl2"}, | ||||
| 		{ID: "cl3"}, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) { | ||||
| 	if h.fail || h.failAt == 2 { | ||||
| 		return v, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return []storage.Subscription{ | ||||
| 		{ID: "sub1"}, | ||||
| 		{ID: "sub2"}, | ||||
| 		{ID: "sub3"}, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) { | ||||
| 	if h.fail || h.failAt == 3 { | ||||
| 		return v, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return []storage.Message{ | ||||
| 		{ID: "r1"}, | ||||
| 		{ID: "r2"}, | ||||
| 		{ID: "r3"}, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) { | ||||
| 	if h.fail || h.failAt == 4 { | ||||
| 		return v, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return []storage.Message{ | ||||
| 		{ID: "i1"}, | ||||
| 		{ID: "i2"}, | ||||
| 		{ID: "i3"}, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) { | ||||
| 	if h.fail || h.failAt == 5 { | ||||
| 		return v, errTestHook | ||||
| 	} | ||||
|  | ||||
| 	return storage.SystemInfo{ | ||||
| 		Info: system.Info{ | ||||
| 			Version: "2.0.0", | ||||
| 		}, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| type providesCheckHook struct { | ||||
| 	HookBase | ||||
| } | ||||
|  | ||||
| func (h *providesCheckHook) Provides(b byte) bool { | ||||
| 	return b == OnConnect | ||||
| } | ||||
|  | ||||
| func TestHooksProvides(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	err := h.Add(new(providesCheckHook), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = h.Add(new(HookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.True(t, h.Provides(OnConnect, OnDisconnect)) | ||||
| 	require.False(t, h.Provides(OnDisconnect)) | ||||
| } | ||||
|  | ||||
| func TestHooksAddAndLen(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	err := h.Add(new(HookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, int64(1), atomic.LoadInt64(&h.qty)) | ||||
| 	require.Equal(t, int64(1), h.Len()) | ||||
| } | ||||
|  | ||||
| func TestHooksAddInitFailure(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	err := h.Add(new(modifiedHookBase), map[string]any{}) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, int64(0), atomic.LoadInt64(&h.qty)) | ||||
| } | ||||
|  | ||||
| func TestHooksStop(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	err := h.Add(new(HookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, int64(1), atomic.LoadInt64(&h.qty)) | ||||
| 	require.Equal(t, int64(1), h.Len()) | ||||
|  | ||||
| 	h.Stop() | ||||
| } | ||||
|  | ||||
| // coverage: also cover some empty functions | ||||
| func TestHooksNonReturns(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	cl := new(Client) | ||||
|  | ||||
| 	for i := 0; i < 2; i++ { | ||||
| 		t.Run("step-"+strconv.Itoa(i), func(t *testing.T) { | ||||
| 			// on first iteration, check without hook methods | ||||
| 			h.OnStarted() | ||||
| 			h.OnStopped() | ||||
| 			h.OnSysInfoTick(new(system.Info)) | ||||
| 			h.OnConnect(cl, packets.Packet{}) | ||||
| 			h.OnSessionEstablished(cl, packets.Packet{}) | ||||
| 			h.OnDisconnect(cl, nil, false) | ||||
| 			h.OnPacketSent(cl, packets.Packet{}, []byte{}) | ||||
| 			h.OnPacketProcessed(cl, packets.Packet{}, nil) | ||||
| 			h.OnSubscribed(cl, packets.Packet{}, []byte{1}) | ||||
| 			h.OnUnsubscribed(cl, packets.Packet{}) | ||||
| 			h.OnPublished(cl, packets.Packet{}) | ||||
| 			h.OnRetainMessage(cl, packets.Packet{}, 0) | ||||
| 			h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0) | ||||
| 			h.OnQosComplete(cl, packets.Packet{}) | ||||
| 			h.OnQosDropped(cl, packets.Packet{}) | ||||
| 			h.OnWillSent(cl, packets.Packet{}) | ||||
| 			h.OnClientExpired(cl) | ||||
| 			h.OnRetainedExpired("a/b/c") | ||||
| 			h.OnExpireInflights(cl, time.Now().Unix()-1) | ||||
|  | ||||
| 			// on second iteration, check added hook methods | ||||
| 			err := h.Add(new(modifiedHookBase), nil) | ||||
| 			require.NoError(t, err) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHooksOnConnectAuthenticate(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
|  | ||||
| 	ok := h.OnConnectAuthenticate(new(Client), packets.Packet{}) | ||||
| 	require.False(t, ok) | ||||
|  | ||||
| 	err := h.Add(new(modifiedHookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	ok = h.OnConnectAuthenticate(new(Client), packets.Packet{}) | ||||
| 	require.True(t, ok) | ||||
| } | ||||
|  | ||||
| func TestHooksOnACLCheck(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
|  | ||||
| 	ok := h.OnACLCheck(new(Client), "a/b/c", true) | ||||
| 	require.False(t, ok) | ||||
|  | ||||
| 	err := h.Add(new(modifiedHookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	ok = h.OnACLCheck(new(Client), "a/b/c", true) | ||||
| 	require.True(t, ok) | ||||
| } | ||||
|  | ||||
| func TestHooksOnSubscribe(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	err := h.Add(new(modifiedHookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	pki := packets.Packet{ | ||||
| 		Filters: packets.Subscriptions{ | ||||
| 			{Filter: "a/b/c", Qos: 1}, | ||||
| 		}, | ||||
| 	} | ||||
| 	pk := h.OnSubscribe(new(Client), pki) | ||||
| 	require.EqualValues(t, pk, pki) | ||||
| } | ||||
|  | ||||
| func TestHooksOnSelectSubscribers(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	err := h.Add(new(modifiedHookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	subs := &Subscribers{ | ||||
| 		Subscriptions: map[string]packets.Subscription{ | ||||
| 			"cl1": {Filter: "a/b/c"}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	subs2 := h.OnSelectSubscribers(subs, packets.Packet{}) | ||||
| 	require.EqualValues(t, subs, subs2) | ||||
| } | ||||
|  | ||||
| func TestHooksOnUnsubscribe(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	err := h.Add(new(modifiedHookBase), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	pki := packets.Packet{ | ||||
| 		Filters: packets.Subscriptions{ | ||||
| 			{Filter: "a/b/c", Qos: 1}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	pk := h.OnUnsubscribe(new(Client), pki) | ||||
| 	require.EqualValues(t, pk, pki) | ||||
| } | ||||
|  | ||||
| func TestHooksOnPublish(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err := h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
|  | ||||
| 	// coverage: failure | ||||
| 	hook.fail = true | ||||
| 	pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
|  | ||||
| 	// coverage: reject packet | ||||
| 	hook.err = packets.ErrRejectPacket | ||||
| 	pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, packets.ErrRejectPacket) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
| } | ||||
|  | ||||
| func TestHooksOnPacketRead(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err := h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
|  | ||||
| 	// coverage: failure | ||||
| 	hook.fail = true | ||||
| 	pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
|  | ||||
| 	// coverage: reject packet | ||||
| 	hook.err = packets.ErrRejectPacket | ||||
| 	pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, packets.ErrRejectPacket) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
| } | ||||
|  | ||||
| func TestHooksOnAuthPacket(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err := h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
|  | ||||
| 	hook.fail = true | ||||
| 	pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
| } | ||||
|  | ||||
| func TestHooksOnPacketEncode(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err := h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
| } | ||||
|  | ||||
| func TestHooksOnLWT(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err := h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"}) | ||||
| 	require.Equal(t, "a/b/c", lwt.TopicName) | ||||
|  | ||||
| 	// coverage: fail lwt | ||||
| 	hook.fail = true | ||||
| 	lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"}) | ||||
| 	require.Equal(t, "a/b/c", lwt.TopicName) | ||||
| } | ||||
|  | ||||
| func TestHooksStoredClients(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	v, err := h.StoredClients() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, v, 0) | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err = h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	v, err = h.StoredClients() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, v, 3) | ||||
|  | ||||
| 	hook.fail = true | ||||
| 	v, err = h.StoredClients() | ||||
| 	require.Error(t, err) | ||||
| 	require.Len(t, v, 0) | ||||
| } | ||||
|  | ||||
| func TestHooksStoredSubscriptions(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	v, err := h.StoredSubscriptions() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, v, 0) | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err = h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	v, err = h.StoredSubscriptions() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, v, 3) | ||||
|  | ||||
| 	hook.fail = true | ||||
| 	v, err = h.StoredSubscriptions() | ||||
| 	require.Error(t, err) | ||||
| 	require.Len(t, v, 0) | ||||
| } | ||||
|  | ||||
| func TestHooksStoredRetainedMessages(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	v, err := h.StoredRetainedMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, v, 0) | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err = h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	v, err = h.StoredRetainedMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, v, 3) | ||||
|  | ||||
| 	hook.fail = true | ||||
| 	v, err = h.StoredRetainedMessages() | ||||
| 	require.Error(t, err) | ||||
| 	require.Len(t, v, 0) | ||||
| } | ||||
|  | ||||
| func TestHooksStoredInflightMessages(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	v, err := h.StoredInflightMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, v, 0) | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err = h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	v, err = h.StoredInflightMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Len(t, v, 3) | ||||
|  | ||||
| 	hook.fail = true | ||||
| 	v, err = h.StoredInflightMessages() | ||||
| 	require.Error(t, err) | ||||
| 	require.Len(t, v, 0) | ||||
| } | ||||
|  | ||||
| func TestHooksStoredSysInfo(t *testing.T) { | ||||
| 	h := new(Hooks) | ||||
| 	h.Log = &logger | ||||
|  | ||||
| 	v, err := h.StoredSysInfo() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "", v.Info.Version) | ||||
|  | ||||
| 	hook := new(modifiedHookBase) | ||||
| 	err = h.Add(hook, nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	v, err = h.StoredSysInfo() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "2.0.0", v.Info.Version) | ||||
|  | ||||
| 	hook.fail = true | ||||
| 	v, err = h.StoredSysInfo() | ||||
| 	require.Error(t, err) | ||||
| 	require.Equal(t, "", v.Info.Version) | ||||
| } | ||||
|  | ||||
| func TestHookBaseID(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	require.Equal(t, "base", h.ID()) | ||||
| } | ||||
|  | ||||
| func TestHookBaseProvidesNone(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	require.False(t, h.Provides(OnConnect)) | ||||
| 	require.False(t, h.Provides(OnDisconnect)) | ||||
| } | ||||
|  | ||||
| func TestHookBaseInit(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	require.Nil(t, h.Init(nil)) | ||||
| } | ||||
|  | ||||
| func TestHookBaseSetOpts(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	h.SetOpts(&logger, new(HookOptions)) | ||||
| 	require.NotNil(t, h.Log) | ||||
| 	require.NotNil(t, h.Opts) | ||||
| } | ||||
|  | ||||
| func TestHookBaseClose(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	require.Nil(t, h.Stop()) | ||||
| } | ||||
|  | ||||
| func TestHookBaseOnConnectAuthenticate(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	v := h.OnConnectAuthenticate(new(Client), packets.Packet{}) | ||||
| 	require.False(t, v) | ||||
| } | ||||
| func TestHookBaseOnACLCheck(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	v := h.OnACLCheck(new(Client), "topic", true) | ||||
| 	require.False(t, v) | ||||
| } | ||||
|  | ||||
| func TestHookBaseOnPublish(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
| } | ||||
|  | ||||
| func TestHookBaseOnPacketRead(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
| } | ||||
|  | ||||
| func TestHookBaseOnAuthPacket(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, uint16(10), pk.PacketID) | ||||
| } | ||||
|  | ||||
| func TestHookBaseOnLWT(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"}) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "a/b/c", lwt.TopicName) | ||||
| } | ||||
|  | ||||
| func TestHookBaseStoredClients(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	v, err := h.StoredClients() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Empty(t, v) | ||||
| } | ||||
|  | ||||
| func TestHookBaseStoredSubscriptions(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	v, err := h.StoredSubscriptions() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Empty(t, v) | ||||
| } | ||||
|  | ||||
| func TestHookBaseStoredInflightMessages(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	v, err := h.StoredInflightMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Empty(t, v) | ||||
| } | ||||
|  | ||||
| func TestHookBaseStoredRetainedMessages(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	v, err := h.StoredRetainedMessages() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Empty(t, v) | ||||
| } | ||||
|  | ||||
| func TestHookBaseStoreSysInfo(t *testing.T) { | ||||
| 	h := new(HookBase) | ||||
| 	v, err := h.StoredSysInfo() | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "", v.Version) | ||||
| } | ||||
							
								
								
									
										144
									
								
								inflight.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								inflight.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,144 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 J. Blake / mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"sort" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| ) | ||||
|  | ||||
| // Inflight is a map of InflightMessage keyed on packet id. | ||||
| type Inflight struct { | ||||
| 	sync.RWMutex | ||||
| 	internal            map[uint16]packets.Packet // internal contains the inflight packets | ||||
| 	receiveQuota        int32                     // remaining inbound qos quota for flow control | ||||
| 	sendQuota           int32                     // remaining outbound qos quota for flow control | ||||
| 	maximumReceiveQuota int32                     // maximum allowed receive quota | ||||
| 	maximumSendQuota    int32                     // maximum allowed send quota | ||||
| } | ||||
|  | ||||
| // NewInflights returns a new instance of an Inflight packets map. | ||||
| func NewInflights() *Inflight { | ||||
| 	return &Inflight{ | ||||
| 		internal: map[uint16]packets.Packet{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Set adds or updates an inflight packet by packet id. | ||||
| func (i *Inflight) Set(m packets.Packet) bool { | ||||
| 	i.Lock() | ||||
| 	defer i.Unlock() | ||||
|  | ||||
| 	_, ok := i.internal[m.PacketID] | ||||
| 	i.internal[m.PacketID] = m | ||||
| 	return !ok | ||||
| } | ||||
|  | ||||
| // Get returns an inflight packet by packet id. | ||||
| func (i *Inflight) Get(id uint16) (packets.Packet, bool) { | ||||
| 	i.RLock() | ||||
| 	defer i.RUnlock() | ||||
|  | ||||
| 	if m, ok := i.internal[id]; ok { | ||||
| 		return m, true | ||||
| 	} | ||||
|  | ||||
| 	return packets.Packet{}, false | ||||
| } | ||||
|  | ||||
| // Len returns the size of the inflight messages map. | ||||
| func (i *Inflight) Len() int { | ||||
| 	i.RLock() | ||||
| 	defer i.RUnlock() | ||||
| 	return len(i.internal) | ||||
| } | ||||
|  | ||||
| // GetAll returns all the inflight messages. | ||||
| func (i *Inflight) GetAll(immediate bool) []packets.Packet { | ||||
| 	i.RLock() | ||||
| 	defer i.RUnlock() | ||||
|  | ||||
| 	m := []packets.Packet{} | ||||
| 	for _, v := range i.internal { | ||||
| 		if !immediate || (immediate && v.Expiry < 0) { | ||||
| 			m = append(m, v) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	sort.Slice(m, func(i, j int) bool { | ||||
| 		return uint16(m[i].Created) < uint16(m[j].Created) | ||||
| 	}) | ||||
|  | ||||
| 	return m | ||||
| } | ||||
|  | ||||
| // NextImmediate returns the next inflight packet which is indicated to be sent immediately. | ||||
| // This typically occurs when the quota has been exhausted, and we need to wait until new quota | ||||
| // is free to continue sending. | ||||
| func (i *Inflight) NextImmediate() (packets.Packet, bool) { | ||||
| 	i.RLock() | ||||
| 	defer i.RUnlock() | ||||
|  | ||||
| 	m := i.GetAll(true) | ||||
| 	if len(m) > 0 { | ||||
| 		return m[0], true | ||||
| 	} | ||||
|  | ||||
| 	return packets.Packet{}, false | ||||
| } | ||||
|  | ||||
| // Delete removes an in-flight message from the map. Returns true if the message existed. | ||||
| func (i *Inflight) Delete(id uint16) bool { | ||||
| 	i.Lock() | ||||
| 	defer i.Unlock() | ||||
|  | ||||
| 	_, ok := i.internal[id] | ||||
| 	delete(i.internal, id) | ||||
|  | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| // TakeRecieveQuota reduces the receive quota by 1. | ||||
| func (i *Inflight) TakeReceiveQuota() { | ||||
| 	if atomic.LoadInt32(&i.receiveQuota) > 0 { | ||||
| 		atomic.AddInt32(&i.receiveQuota, -1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TakeRecieveQuota increases the receive quota by 1. | ||||
| func (i *Inflight) ReturnReceiveQuota() { | ||||
| 	if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) { | ||||
| 		atomic.AddInt32(&i.receiveQuota, 1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ResetReceiveQuota resets the receive quota to the maximum allowed value. | ||||
| func (i *Inflight) ResetReceiveQuota(n int32) { | ||||
| 	atomic.StoreInt32(&i.receiveQuota, n) | ||||
| 	atomic.StoreInt32(&i.maximumReceiveQuota, n) | ||||
| } | ||||
|  | ||||
| // TakeSendQuota reduces the send quota by 1. | ||||
| func (i *Inflight) TakeSendQuota() { | ||||
| 	if atomic.LoadInt32(&i.sendQuota) > 0 { | ||||
| 		atomic.AddInt32(&i.sendQuota, -1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReturnSendQuota increases the send quota by 1. | ||||
| func (i *Inflight) ReturnSendQuota() { | ||||
| 	if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) { | ||||
| 		atomic.AddInt32(&i.sendQuota, 1) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ResetSendQuota resets the send quota to the maximum allowed value. | ||||
| func (i *Inflight) ResetSendQuota(n int32) { | ||||
| 	atomic.StoreInt32(&i.sendQuota, n) | ||||
| 	atomic.StoreInt32(&i.maximumSendQuota, n) | ||||
| } | ||||
							
								
								
									
										189
									
								
								inflight_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								inflight_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,189 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package mqtt | ||||
|  | ||||
| import ( | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/packets" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestInflightSet(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	r := cl.State.Inflight.Set(packets.Packet{PacketID: 1}) | ||||
| 	require.True(t, r) | ||||
| 	require.NotNil(t, cl.State.Inflight.internal[1]) | ||||
| 	require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID) | ||||
|  | ||||
| 	r = cl.State.Inflight.Set(packets.Packet{PacketID: 1}) | ||||
| 	require.False(t, r) | ||||
| } | ||||
|  | ||||
| func TestInflightGet(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | ||||
|  | ||||
| 	msg, ok := cl.State.Inflight.Get(2) | ||||
| 	require.True(t, ok) | ||||
| 	require.NotEqual(t, 0, msg.PacketID) | ||||
| } | ||||
|  | ||||
| func TestInflightGetAllAndImmediate(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5}) | ||||
|  | ||||
| 	require.Equal(t, []packets.Packet{ | ||||
| 		{PacketID: 1, Created: 1}, | ||||
| 		{PacketID: 2, Created: 2}, | ||||
| 		{PacketID: 3, Created: 3, Expiry: -1}, | ||||
| 		{PacketID: 4, Created: 4, Expiry: -1}, | ||||
| 		{PacketID: 5, Created: 5}, | ||||
| 	}, cl.State.Inflight.GetAll(false)) | ||||
|  | ||||
| 	require.Equal(t, []packets.Packet{ | ||||
| 		{PacketID: 3, Created: 3, Expiry: -1}, | ||||
| 		{PacketID: 4, Created: 4, Expiry: -1}, | ||||
| 	}, cl.State.Inflight.GetAll(true)) | ||||
| } | ||||
|  | ||||
| func TestInflightLen(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2}) | ||||
| 	require.Equal(t, 1, cl.State.Inflight.Len()) | ||||
| } | ||||
|  | ||||
| func TestInflightDelete(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
|  | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 3}) | ||||
| 	require.NotNil(t, cl.State.Inflight.internal[3]) | ||||
|  | ||||
| 	r := cl.State.Inflight.Delete(3) | ||||
| 	require.True(t, r) | ||||
| 	require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID) | ||||
|  | ||||
| 	_, ok := cl.State.Inflight.Get(3) | ||||
| 	require.False(t, ok) | ||||
|  | ||||
| 	r = cl.State.Inflight.Delete(3) | ||||
| 	require.False(t, r) | ||||
| } | ||||
|  | ||||
| func TestResetReceiveQuota(t *testing.T) { | ||||
| 	i := NewInflights() | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | ||||
| 	i.ResetReceiveQuota(6) | ||||
| 	require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota)) | ||||
| } | ||||
|  | ||||
| func TestReceiveQuota(t *testing.T) { | ||||
| 	i := NewInflights() | ||||
| 	i.receiveQuota = 4 | ||||
| 	i.maximumReceiveQuota = 5 | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Return 1 | ||||
| 	i.ReturnReceiveQuota() | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Try to go over max limit | ||||
| 	i.ReturnReceiveQuota() | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Reset to max 1 | ||||
| 	i.ResetReceiveQuota(1) | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Take 1 | ||||
| 	i.TakeReceiveQuota() | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | ||||
|  | ||||
| 	// Try to go below zero | ||||
| 	i.TakeReceiveQuota() | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) | ||||
| } | ||||
|  | ||||
| func TestResetSendQuota(t *testing.T) { | ||||
| 	i := NewInflights() | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | ||||
| 	i.ResetSendQuota(6) | ||||
| 	require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota)) | ||||
| } | ||||
|  | ||||
| func TestSendQuota(t *testing.T) { | ||||
| 	i := NewInflights() | ||||
| 	i.sendQuota = 4 | ||||
| 	i.maximumSendQuota = 5 | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Return 1 | ||||
| 	i.ReturnSendQuota() | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Try to go over max limit | ||||
| 	i.ReturnSendQuota() | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Reset to max 1 | ||||
| 	i.ResetSendQuota(1) | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Take 1 | ||||
| 	i.TakeSendQuota() | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | ||||
|  | ||||
| 	// Try to go below zero | ||||
| 	i.TakeSendQuota() | ||||
| 	require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) | ||||
| 	require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) | ||||
| } | ||||
|  | ||||
| func TestNextImmediate(t *testing.T) { | ||||
| 	cl, _, _ := newTestClient() | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1}) | ||||
| 	cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5}) | ||||
|  | ||||
| 	pk, ok := cl.State.Inflight.NextImmediate() | ||||
| 	require.True(t, ok) | ||||
| 	require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk) | ||||
|  | ||||
| 	r := cl.State.Inflight.Delete(3) | ||||
| 	require.True(t, r) | ||||
|  | ||||
| 	pk, ok = cl.State.Inflight.NextImmediate() | ||||
| 	require.True(t, ok) | ||||
| 	require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk) | ||||
|  | ||||
| 	r = cl.State.Inflight.Delete(4) | ||||
| 	require.True(t, r) | ||||
|  | ||||
| 	_, ok = cl.State.Inflight.NextImmediate() | ||||
| 	require.False(t, ok) | ||||
| } | ||||
							
								
								
									
										139
									
								
								listeners/http_sysinfo.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								listeners/http_sysinfo.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,139 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| // HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. | ||||
| type HTTPStats struct { | ||||
| 	sync.RWMutex | ||||
| 	id      string          // the internal id of the listener | ||||
| 	address string          // the network address to bind to | ||||
| 	config  *Config         // configuration values for the listener | ||||
| 	listen  *http.Server    // the http server | ||||
| 	log     *zerolog.Logger // server logger | ||||
| 	sysInfo *system.Info    // pointers to the server data | ||||
| 	end     uint32          // ensure the close methods are only called once | ||||
| } | ||||
|  | ||||
| // NewHTTPStats initialises and returns a new HTTP listener, listening on an address. | ||||
| func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats { | ||||
| 	if config == nil { | ||||
| 		config = new(Config) | ||||
| 	} | ||||
| 	return &HTTPStats{ | ||||
| 		id:      id, | ||||
| 		address: address, | ||||
| 		sysInfo: sysInfo, | ||||
| 		config:  config, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ID returns the id of the listener. | ||||
| func (l *HTTPStats) ID() string { | ||||
| 	return l.id | ||||
| } | ||||
|  | ||||
| // Address returns the address of the listener. | ||||
| func (l *HTTPStats) Address() string { | ||||
| 	return l.address | ||||
| } | ||||
|  | ||||
| // Protocol returns the address of the listener. | ||||
| func (l *HTTPStats) Protocol() string { | ||||
| 	if l.listen != nil && l.listen.TLSConfig != nil { | ||||
| 		return "https" | ||||
| 	} | ||||
|  | ||||
| 	return "http" | ||||
| } | ||||
|  | ||||
| // Init initializes the listener. | ||||
| func (l *HTTPStats) Init(log *zerolog.Logger) error { | ||||
| 	l.log = log | ||||
|  | ||||
| 	mux := http.NewServeMux() | ||||
| 	mux.HandleFunc("/", l.jsonHandler) | ||||
| 	l.listen = &http.Server{ | ||||
| 		ReadTimeout:  5 * time.Second, | ||||
| 		WriteTimeout: 5 * time.Second, | ||||
| 		Addr:         l.address, | ||||
| 		Handler:      mux, | ||||
| 	} | ||||
|  | ||||
| 	if l.config.TLSConfig != nil { | ||||
| 		l.listen.TLSConfig = l.config.TLSConfig | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Serve starts listening for new connections and serving responses. | ||||
| func (l *HTTPStats) Serve(establish EstablishFn) { | ||||
| 	if l.listen.TLSConfig != nil { | ||||
| 		l.listen.ListenAndServeTLS("", "") | ||||
| 	} else { | ||||
| 		l.listen.ListenAndServe() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Close closes the listener and any client connections. | ||||
| func (l *HTTPStats) Close(closeClients CloseFn) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
|  | ||||
| 	if atomic.CompareAndSwapUint32(&l.end, 0, 1) { | ||||
| 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 		defer cancel() | ||||
| 		l.listen.Shutdown(ctx) | ||||
| 	} | ||||
|  | ||||
| 	closeClients(l.id) | ||||
| } | ||||
|  | ||||
| // jsonHandler is an HTTP handler which outputs the $SYS stats as JSON. | ||||
| func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) { | ||||
| 	info := &system.Info{ | ||||
| 		Version:             l.sysInfo.Version, | ||||
| 		Started:             atomic.LoadInt64(&l.sysInfo.Started), | ||||
| 		Time:                atomic.LoadInt64(&l.sysInfo.Time), | ||||
| 		Uptime:              atomic.LoadInt64(&l.sysInfo.Uptime), | ||||
| 		BytesReceived:       atomic.LoadInt64(&l.sysInfo.BytesReceived), | ||||
| 		BytesSent:           atomic.LoadInt64(&l.sysInfo.BytesSent), | ||||
| 		ClientsConnected:    atomic.LoadInt64(&l.sysInfo.ClientsConnected), | ||||
| 		ClientsMaximum:      atomic.LoadInt64(&l.sysInfo.ClientsMaximum), | ||||
| 		ClientsTotal:        atomic.LoadInt64(&l.sysInfo.ClientsTotal), | ||||
| 		ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected), | ||||
| 		MessagesReceived:    atomic.LoadInt64(&l.sysInfo.MessagesReceived), | ||||
| 		MessagesSent:        atomic.LoadInt64(&l.sysInfo.MessagesSent), | ||||
| 		InflightDropped:     atomic.LoadInt64(&l.sysInfo.InflightDropped), | ||||
| 		Subscriptions:       atomic.LoadInt64(&l.sysInfo.Subscriptions), | ||||
| 		PacketsReceived:     atomic.LoadInt64(&l.sysInfo.PacketsReceived), | ||||
| 		PacketsSent:         atomic.LoadInt64(&l.sysInfo.PacketsSent), | ||||
| 		Retained:            atomic.LoadInt64(&l.sysInfo.Retained), | ||||
| 		Inflight:            atomic.LoadInt64(&l.sysInfo.Inflight), | ||||
| 		MemoryAlloc:         atomic.LoadInt64(&l.sysInfo.MemoryAlloc), | ||||
| 		Threads:             atomic.LoadInt64(&l.sysInfo.Threads), | ||||
| 	} | ||||
|  | ||||
| 	out, err := json.MarshalIndent(info, "", "\t") | ||||
| 	if err != nil { | ||||
| 		io.WriteString(w, err.Error()) | ||||
| 	} | ||||
|  | ||||
| 	w.Write(out) | ||||
| } | ||||
							
								
								
									
										127
									
								
								listeners/http_sysinfo_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								listeners/http_sysinfo_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,127 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/v2/system" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewHTTPStats(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testAddr, nil, nil) | ||||
| 	require.Equal(t, "t1", l.id) | ||||
| 	require.Equal(t, testAddr, l.address) | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsID(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testAddr, nil, nil) | ||||
| 	require.Equal(t, "t1", l.ID()) | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsAddress(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testAddr, nil, nil) | ||||
| 	require.Equal(t, testAddr, l.Address()) | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsProtocol(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testAddr, nil, nil) | ||||
| 	require.Equal(t, "http", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsTLSProtocol(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}, nil) | ||||
|  | ||||
| 	l.Init(nil) | ||||
| 	require.Equal(t, "https", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsInit(t *testing.T) { | ||||
| 	sysInfo := new(system.Info) | ||||
| 	l := NewHTTPStats("t1", testAddr, nil, sysInfo) | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.NotNil(t, l.sysInfo) | ||||
| 	require.Equal(t, sysInfo, l.sysInfo) | ||||
| 	require.NotNil(t, l.listen) | ||||
| 	require.Equal(t, testAddr, l.listen.Addr) | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsServeAndClose(t *testing.T) { | ||||
| 	sysInfo := &system.Info{ | ||||
| 		Version: "test", | ||||
| 	} | ||||
|  | ||||
| 	// setup http stats listener | ||||
| 	l := NewHTTPStats("t1", testAddr, nil, sysInfo) | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	// get body from stats address | ||||
| 	resp, err := http.Get("http://localhost" + testAddr) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, resp) | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	// decode body from json and check data | ||||
| 	v := new(system.Info) | ||||
| 	err = json.Unmarshal(body, v) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "test", v.Version) | ||||
|  | ||||
| 	// ensure listening is closed | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
|  | ||||
| 	require.Equal(t, true, closed) | ||||
|  | ||||
| 	_, err = http.Get("http://localhost" + testAddr) | ||||
| 	require.Error(t, err) | ||||
| 	<-o | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsServeTLSAndClose(t *testing.T) { | ||||
| 	sysInfo := &system.Info{ | ||||
| 		Version: "test", | ||||
| 	} | ||||
|  | ||||
| 	l := NewHTTPStats("t1", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}, sysInfo) | ||||
|  | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	l.Close(MockCloser) | ||||
| } | ||||
| @@ -1,94 +1,91 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
| 
 | ||||
| package listeners | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| 	"github.com/mochi-co/mqtt/server/system" | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
| 
 | ||||
| // Config contains configuration values for a listener. | ||||
| type Config struct { | ||||
| 	Auth auth.Controller // an authentication controller containing auth and ACL logic. | ||||
| 	TLS  *TLS            // the TLS certficates and settings for the connection. | ||||
| 	// TLSConfig is a tls.Config configuration to be used with the listener. | ||||
| 	// See examples folder for basic and mutual-tls use. | ||||
| 	TLSConfig *tls.Config | ||||
| } | ||||
| 
 | ||||
| // TLS contains the TLS certificates and settings for the listener connection. | ||||
| type TLS struct { | ||||
| 	Certificate []byte // the body of a public certificate. | ||||
| 	PrivateKey  []byte // the body of a private key. | ||||
| } | ||||
| 
 | ||||
| // EstablishFunc is a callback function for establishing new clients. | ||||
| type EstablishFunc func(id string, c net.Conn, ac auth.Controller) error | ||||
| // EstablishFn is a callback function for establishing new clients. | ||||
| type EstablishFn func(id string, c net.Conn) error | ||||
| 
 | ||||
| // CloseFunc is a callback function for closing all listener clients. | ||||
| type CloseFunc func(id string) | ||||
| type CloseFn func(id string) | ||||
| 
 | ||||
| // Listener is an interface for network listeners. A network listener listens | ||||
| // for incoming client connections and adds them to the server. | ||||
| type Listener interface { | ||||
| 	SetConfig(*Config)           // set the listener config. | ||||
| 	Listen(s *system.Info) error // open the network address. | ||||
| 	Serve(EstablishFunc)         // starting actively listening for new connections. | ||||
| 	ID() string                  // return the id of the listener. | ||||
| 	Close(CloseFunc)             // stop and close the listener. | ||||
| 	Init(*zerolog.Logger) error // open the network address | ||||
| 	Serve(EstablishFn)          // starting actively listening for new connections | ||||
| 	ID() string                 // return the id of the listener | ||||
| 	Address() string            // the address of the listener | ||||
| 	Protocol() string           // the protocol in use by the listener | ||||
| 	Close(CloseFn)              // stop and close the listener | ||||
| } | ||||
| 
 | ||||
| // Listeners contains the network listeners for the broker. | ||||
| type Listeners struct { | ||||
| 	sync.RWMutex | ||||
| 	wg       sync.WaitGroup      // a waitgroup that waits for all listeners to finish. | ||||
| 	internal map[string]Listener // a map of active listeners. | ||||
| 	system   *system.Info        // pointers to system info. | ||||
| 	sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| // New returns a new instance of Listeners. | ||||
| func New(s *system.Info) *Listeners { | ||||
| func New() *Listeners { | ||||
| 	return &Listeners{ | ||||
| 		internal: map[string]Listener{}, | ||||
| 		system:   s, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Add adds a new listener to the listeners map, keyed on id. | ||||
| func (l *Listeners) Add(val Listener) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
| 	l.internal[val.ID()] = val | ||||
| 	l.Unlock() | ||||
| } | ||||
| 
 | ||||
| // Get returns the value of a listener if it exists. | ||||
| func (l *Listeners) Get(id string) (Listener, bool) { | ||||
| 	l.RLock() | ||||
| 	defer l.RUnlock() | ||||
| 	val, ok := l.internal[id] | ||||
| 	l.RUnlock() | ||||
| 	return val, ok | ||||
| } | ||||
| 
 | ||||
| // Len returns the length of the listeners map. | ||||
| func (l *Listeners) Len() int { | ||||
| 	l.RLock() | ||||
| 	val := len(l.internal) | ||||
| 	l.RUnlock() | ||||
| 	return val | ||||
| 	defer l.RUnlock() | ||||
| 	return len(l.internal) | ||||
| } | ||||
| 
 | ||||
| // Delete removes a listener from the internal map. | ||||
| func (l *Listeners) Delete(id string) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
| 	delete(l.internal, id) | ||||
| 	l.Unlock() | ||||
| } | ||||
| 
 | ||||
| // Serve starts a listener serving from the internal map. | ||||
| func (l *Listeners) Serve(id string, establisher EstablishFunc) { | ||||
| func (l *Listeners) Serve(id string, establisher EstablishFn) { | ||||
| 	l.RLock() | ||||
| 	defer l.RUnlock() | ||||
| 	listener := l.internal[id] | ||||
| 	l.RUnlock() | ||||
| 
 | ||||
| 	go func(e EstablishFunc) { | ||||
| 	go func(e EstablishFn) { | ||||
| 		defer l.wg.Done() | ||||
| 		l.wg.Add(1) | ||||
| 		listener.Serve(e) | ||||
| @@ -96,7 +93,7 @@ func (l *Listeners) Serve(id string, establisher EstablishFunc) { | ||||
| } | ||||
| 
 | ||||
| // ServeAll starts all listeners serving from the internal map. | ||||
| func (l *Listeners) ServeAll(establisher EstablishFunc) { | ||||
| func (l *Listeners) ServeAll(establisher EstablishFn) { | ||||
| 	l.RLock() | ||||
| 	i := 0 | ||||
| 	ids := make([]string, len(l.internal)) | ||||
| @@ -112,15 +109,16 @@ func (l *Listeners) ServeAll(establisher EstablishFunc) { | ||||
| } | ||||
| 
 | ||||
| // Close stops a listener from the internal map. | ||||
| func (l *Listeners) Close(id string, closer CloseFunc) { | ||||
| func (l *Listeners) Close(id string, closer CloseFn) { | ||||
| 	l.RLock() | ||||
| 	listener := l.internal[id] | ||||
| 	l.RUnlock() | ||||
| 	listener.Close(closer) | ||||
| 	defer l.RUnlock() | ||||
| 	if listener, ok := l.internal[id]; ok { | ||||
| 		listener.Close(closer) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // CloseAll iterates and closes all registered listeners. | ||||
| func (l *Listeners) CloseAll(closer CloseFunc) { | ||||
| func (l *Listeners) CloseAll(closer CloseFn) { | ||||
| 	l.RLock() | ||||
| 	i := 0 | ||||
| 	ids := make([]string, len(l.internal)) | ||||
							
								
								
									
										177
									
								
								listeners/listeners_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								listeners/listeners_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,177 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/rs/zerolog" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| const testAddr = ":22222" | ||||
|  | ||||
| var ( | ||||
| 	logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled) | ||||
|  | ||||
| 	testCertificate = []byte(`-----BEGIN CERTIFICATE----- | ||||
| MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB | ||||
| VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV | ||||
| BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD | ||||
| VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x | ||||
| DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3 | ||||
| AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi | ||||
| OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI | ||||
| MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD | ||||
| gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ | ||||
| qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy | ||||
| zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw= | ||||
| -----END CERTIFICATE-----`) | ||||
|  | ||||
| 	testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- | ||||
| MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o | ||||
| FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA | ||||
| rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB | ||||
| AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K | ||||
| UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m | ||||
| n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ | ||||
| mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6 | ||||
| INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z | ||||
| AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt | ||||
| /F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32 | ||||
| WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy | ||||
| w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3 | ||||
| OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc= | ||||
| -----END RSA PRIVATE KEY-----`) | ||||
|  | ||||
| 	tlsConfigBasic *tls.Config | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	cert, err := tls.X509KeyPair(testCertificate, testPrivateKey) | ||||
| 	if err != nil { | ||||
| 		log.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Basic TLS Config | ||||
| 	tlsConfigBasic = &tls.Config{ | ||||
| 		MinVersion:   tls.VersionTLS12, | ||||
| 		Certificates: []tls.Certificate{cert}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestNew(t *testing.T) { | ||||
| 	l := New() | ||||
| 	require.NotNil(t, l.internal) | ||||
| } | ||||
|  | ||||
| func TestAddListener(t *testing.T) { | ||||
| 	l := New() | ||||
| 	l.Add(NewMockListener("t1", testAddr)) | ||||
| 	require.Contains(t, l.internal, "t1") | ||||
| } | ||||
|  | ||||
| func TestGetListener(t *testing.T) { | ||||
| 	l := New() | ||||
| 	l.Add(NewMockListener("t1", testAddr)) | ||||
| 	l.Add(NewMockListener("t2", testAddr)) | ||||
| 	require.Contains(t, l.internal, "t1") | ||||
| 	require.Contains(t, l.internal, "t2") | ||||
|  | ||||
| 	g, ok := l.Get("t1") | ||||
| 	require.True(t, ok) | ||||
| 	require.Equal(t, g.ID(), "t1") | ||||
| } | ||||
|  | ||||
| func TestLenListener(t *testing.T) { | ||||
| 	l := New() | ||||
| 	l.Add(NewMockListener("t1", testAddr)) | ||||
| 	l.Add(NewMockListener("t2", testAddr)) | ||||
| 	require.Contains(t, l.internal, "t1") | ||||
| 	require.Contains(t, l.internal, "t2") | ||||
| 	require.Equal(t, 2, l.Len()) | ||||
| } | ||||
|  | ||||
| func TestDeleteListener(t *testing.T) { | ||||
| 	l := New() | ||||
| 	l.Add(NewMockListener("t1", testAddr)) | ||||
| 	require.Contains(t, l.internal, "t1") | ||||
| 	l.Delete("t1") | ||||
| 	_, ok := l.Get("t1") | ||||
| 	require.False(t, ok) | ||||
| 	require.Nil(t, l.internal["t1"]) | ||||
| } | ||||
|  | ||||
| func TestServeListener(t *testing.T) { | ||||
| 	l := New() | ||||
| 	l.Add(NewMockListener("t1", testAddr)) | ||||
| 	l.Serve("t1", MockEstablisher) | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	require.True(t, l.internal["t1"].(*MockListener).IsServing()) | ||||
|  | ||||
| 	l.Close("t1", MockCloser) | ||||
| 	require.False(t, l.internal["t1"].(*MockListener).IsServing()) | ||||
| } | ||||
|  | ||||
| func TestServeAllListeners(t *testing.T) { | ||||
| 	l := New() | ||||
| 	l.Add(NewMockListener("t1", testAddr)) | ||||
| 	l.Add(NewMockListener("t2", testAddr)) | ||||
| 	l.Add(NewMockListener("t3", testAddr)) | ||||
| 	l.ServeAll(MockEstablisher) | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	require.True(t, l.internal["t1"].(*MockListener).IsServing()) | ||||
| 	require.True(t, l.internal["t2"].(*MockListener).IsServing()) | ||||
| 	require.True(t, l.internal["t3"].(*MockListener).IsServing()) | ||||
|  | ||||
| 	l.Close("t1", MockCloser) | ||||
| 	l.Close("t2", MockCloser) | ||||
| 	l.Close("t3", MockCloser) | ||||
|  | ||||
| 	require.False(t, l.internal["t1"].(*MockListener).IsServing()) | ||||
| 	require.False(t, l.internal["t2"].(*MockListener).IsServing()) | ||||
| 	require.False(t, l.internal["t3"].(*MockListener).IsServing()) | ||||
| } | ||||
|  | ||||
| func TestCloseListener(t *testing.T) { | ||||
| 	l := New() | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	l.Add(mocked) | ||||
| 	l.Serve("t1", MockEstablisher) | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	var closed bool | ||||
| 	l.Close("t1", func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
| 	require.True(t, closed) | ||||
| } | ||||
|  | ||||
| func TestCloseAllListeners(t *testing.T) { | ||||
| 	l := New() | ||||
| 	l.Add(NewMockListener("t1", testAddr)) | ||||
| 	l.Add(NewMockListener("t2", testAddr)) | ||||
| 	l.Add(NewMockListener("t3", testAddr)) | ||||
| 	l.ServeAll(MockEstablisher) | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	require.True(t, l.internal["t1"].(*MockListener).IsServing()) | ||||
| 	require.True(t, l.internal["t2"].(*MockListener).IsServing()) | ||||
| 	require.True(t, l.internal["t3"].(*MockListener).IsServing()) | ||||
|  | ||||
| 	closed := make(map[string]bool) | ||||
| 	l.CloseAll(func(id string) { | ||||
| 		closed[id] = true | ||||
| 	}) | ||||
| 	require.Contains(t, closed, "t1") | ||||
| 	require.Contains(t, closed, "t2") | ||||
| 	require.Contains(t, closed, "t3") | ||||
| 	require.True(t, closed["t1"]) | ||||
| 	require.True(t, closed["t2"]) | ||||
| 	require.True(t, closed["t3"]) | ||||
| } | ||||
| @@ -1,36 +1,38 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
| 
 | ||||
| package listeners | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| 	"github.com/mochi-co/mqtt/server/system" | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
| 
 | ||||
| // MockEstablisher is a function signature which can be used in testing. | ||||
| func MockEstablisher(id string, c net.Conn) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // MockCloser is a function signature which can be used in testing. | ||||
| func MockCloser(id string) {} | ||||
| 
 | ||||
| // MockEstablisher is a function signature which can be used in testing. | ||||
| func MockEstablisher(id string, c net.Conn, ac auth.Controller) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // MockListener is a mock listener for establishing client connections. | ||||
| type MockListener struct { | ||||
| 	sync.RWMutex | ||||
| 	id        string    // the id of the listener. | ||||
| 	Config    *Config   // configuration for the listener. | ||||
| 	address   string    // the network address the listener binds to. | ||||
| 	Listening bool      // indiciate the listener is listening. | ||||
| 	Serving   bool      // indicate the listener is serving. | ||||
| 	done      chan bool // indicate the listener is done. | ||||
| 	ErrListen bool      // throw an error on listen. | ||||
| 	id        string    // the id of the listener | ||||
| 	address   string    // the network address the listener binds to | ||||
| 	Config    *Config   // configuration for the listener | ||||
| 	done      chan bool // indicate the listener is done | ||||
| 	Serving   bool      // indicate the listener is serving | ||||
| 	Listening bool      // indiciate the listener is listening | ||||
| 	ErrListen bool      // throw an error on listen | ||||
| } | ||||
| 
 | ||||
| // NewMockListener returns a new instance of MockListener | ||||
| // NewMockListener returns a new instance of MockListener. | ||||
| func NewMockListener(id, address string) *MockListener { | ||||
| 	return &MockListener{ | ||||
| 		id:      id, | ||||
| @@ -40,47 +42,45 @@ func NewMockListener(id, address string) *MockListener { | ||||
| } | ||||
| 
 | ||||
| // Serve serves the mock listener. | ||||
| func (l *MockListener) Serve(establisher EstablishFunc) { | ||||
| func (l *MockListener) Serve(establisher EstablishFn) { | ||||
| 	l.Lock() | ||||
| 	l.Serving = true | ||||
| 	l.Unlock() | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-l.done: | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 	for range l.done { | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SetConfig sets the configuration values of the mock listener. | ||||
| func (l *MockListener) Listen(s *system.Info) error { | ||||
| // Init initializes the listener. | ||||
| func (l *MockListener) Init(log *zerolog.Logger) error { | ||||
| 	if l.ErrListen { | ||||
| 		return fmt.Errorf("listen failure") | ||||
| 	} | ||||
| 
 | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
| 	l.Listening = true | ||||
| 	l.Unlock() | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // SetConfig sets the configuration values of the mock listener. | ||||
| func (l *MockListener) SetConfig(config *Config) { | ||||
| 	l.Lock() | ||||
| 	l.Config = config | ||||
| 	l.Unlock() | ||||
| } | ||||
| 
 | ||||
| // ID returns the id of the mock listener. | ||||
| func (l *MockListener) ID() string { | ||||
| 	l.RLock() | ||||
| 	id := l.id | ||||
| 	l.RUnlock() | ||||
| 	return id | ||||
| 	return l.id | ||||
| } | ||||
| 
 | ||||
| // Address returns the address of the listener. | ||||
| func (l *MockListener) Address() string { | ||||
| 	return l.address | ||||
| } | ||||
| 
 | ||||
| // Protocol returns the address of the listener. | ||||
| func (l *MockListener) Protocol() string { | ||||
| 	return "mock" | ||||
| } | ||||
| 
 | ||||
| // Close closes the mock listener. | ||||
| func (l *MockListener) Close(closer CloseFunc) { | ||||
| func (l *MockListener) Close(closer CloseFn) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
| 	l.Serving = false | ||||
| @@ -95,7 +95,7 @@ func (l *MockListener) IsServing() bool { | ||||
| 	return l.Serving | ||||
| } | ||||
| 
 | ||||
| // IsServing indicates whether the mock listener is listening. | ||||
| // IsListening indicates whether the mock listener is listening. | ||||
| func (l *MockListener) IsListening() bool { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
| @@ -1,3 +1,7 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
| 
 | ||||
| package listeners | ||||
| 
 | ||||
| import ( | ||||
| @@ -6,42 +10,64 @@ import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 
 | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| ) | ||||
| 
 | ||||
| func TestMockEstablisher(t *testing.T) { | ||||
| 	_, w := net.Pipe() | ||||
| 	err := MockEstablisher("t1", w, new(auth.Allow)) | ||||
| 	err := MockEstablisher("t1", w) | ||||
| 	require.NoError(t, err) | ||||
| 	w.Close() | ||||
| } | ||||
| 
 | ||||
| func TestNewMockListener(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", ":1882") | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	require.Equal(t, "t1", mocked.id) | ||||
| 	require.Equal(t, ":1882", mocked.address) | ||||
| 	require.Equal(t, testAddr, mocked.address) | ||||
| } | ||||
| func TestMockListenerID(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	require.Equal(t, "t1", mocked.ID()) | ||||
| } | ||||
| 
 | ||||
| func TestNewMockListenerListen(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", ":1882") | ||||
| func TestMockListenerAddress(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	require.Equal(t, testAddr, mocked.Address()) | ||||
| } | ||||
| func TestMockListenerProtocol(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	require.Equal(t, "mock", mocked.Protocol()) | ||||
| } | ||||
| 
 | ||||
| func TestNewMockListenerIsListening(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	require.Equal(t, false, mocked.IsListening()) | ||||
| } | ||||
| 
 | ||||
| func TestNewMockListenerIsServing(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	require.Equal(t, false, mocked.IsServing()) | ||||
| } | ||||
| 
 | ||||
| func TestNewMockListenerInit(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	require.Equal(t, "t1", mocked.id) | ||||
| 	require.Equal(t, ":1882", mocked.address) | ||||
| 	require.Equal(t, testAddr, mocked.address) | ||||
| 
 | ||||
| 	require.Equal(t, false, mocked.IsListening()) | ||||
| 	err := mocked.Listen(nil) | ||||
| 	err := mocked.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, true, mocked.IsListening()) | ||||
| } | ||||
| func TestNewMockListenerListenFailure(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", ":1882") | ||||
| 
 | ||||
| func TestNewMockListenerInitFailure(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	mocked.ErrListen = true | ||||
| 	err := mocked.Listen(nil) | ||||
| 	err := mocked.Init(nil) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
| 
 | ||||
| func TestMockListenerServe(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", ":1882") | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	require.Equal(t, false, mocked.IsServing()) | ||||
| 
 | ||||
| 	o := make(chan bool) | ||||
| @@ -60,30 +86,14 @@ func TestMockListenerServe(t *testing.T) { | ||||
| 	require.Equal(t, true, closed) | ||||
| 	<-o | ||||
| 
 | ||||
| 	mocked.Listen(nil) | ||||
| } | ||||
| 
 | ||||
| func TestMockListenerSetConfig(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", ":1883") | ||||
| 	mocked.SetConfig(new(Config)) | ||||
| 	require.NotNil(t, mocked.Config) | ||||
| 	mocked.Init(nil) | ||||
| } | ||||
| 
 | ||||
| func TestMockListenerClose(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", ":1882") | ||||
| 	mocked := NewMockListener("t1", testAddr) | ||||
| 	var closed bool | ||||
| 	mocked.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
| 	require.Equal(t, true, closed) | ||||
| } | ||||
| 
 | ||||
| func TestNewMockListenerIsListening(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", ":1882") | ||||
| 	require.Equal(t, false, mocked.IsListening()) | ||||
| } | ||||
| 
 | ||||
| func TestNewMockListenerIsServing(t *testing.T) { | ||||
| 	mocked := NewMockListener("t1", ":1882") | ||||
| 	require.Equal(t, false, mocked.IsServing()) | ||||
| } | ||||
							
								
								
									
										108
									
								
								listeners/tcp.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								listeners/tcp.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,108 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
|  | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| // TCP is a listener for establishing client connections on basic TCP protocol. | ||||
| type TCP struct { // [MQTT-4.2.0-1] | ||||
| 	sync.RWMutex | ||||
| 	id      string          // the internal id of the listener | ||||
| 	address string          // the network address to bind to | ||||
| 	listen  net.Listener    // a net.Listener which will listen for new clients | ||||
| 	config  *Config         // configuration values for the listener | ||||
| 	log     *zerolog.Logger // server logger | ||||
| 	end     uint32          // ensure the close methods are only called once | ||||
| } | ||||
|  | ||||
| // NewTCP initialises and returns a new TCP listener, listening on an address. | ||||
| func NewTCP(id, address string, config *Config) *TCP { | ||||
| 	if config == nil { | ||||
| 		config = new(Config) | ||||
| 	} | ||||
|  | ||||
| 	return &TCP{ | ||||
| 		id:      id, | ||||
| 		address: address, | ||||
| 		config:  config, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ID returns the id of the listener. | ||||
| func (l *TCP) ID() string { | ||||
| 	return l.id | ||||
| } | ||||
|  | ||||
| // Address returns the address of the listener. | ||||
| func (l *TCP) Address() string { | ||||
| 	return l.address | ||||
| } | ||||
|  | ||||
| // Protocol returns the address of the listener. | ||||
| func (l *TCP) Protocol() string { | ||||
| 	return "tcp" | ||||
| } | ||||
|  | ||||
| // Init initializes the listener. | ||||
| func (l *TCP) Init(log *zerolog.Logger) error { | ||||
| 	l.log = log | ||||
|  | ||||
| 	var err error | ||||
| 	if l.config.TLSConfig != nil { | ||||
| 		l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig) | ||||
| 	} else { | ||||
| 		l.listen, err = net.Listen("tcp", l.address) | ||||
| 	} | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Serve starts waiting for new TCP connections, and calls the establish | ||||
| // connection callback for any received. | ||||
| func (l *TCP) Serve(establish EstablishFn) { | ||||
| 	for { | ||||
| 		if atomic.LoadUint32(&l.end) == 1 { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		conn, err := l.listen.Accept() | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		if atomic.LoadUint32(&l.end) == 0 { | ||||
| 			go func() { | ||||
| 				err = establish(l.id, conn) | ||||
| 				if err != nil { | ||||
| 					l.log.Warn().Err(err).Send() | ||||
| 				} | ||||
| 			}() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Close closes the listener and any client connections. | ||||
| func (l *TCP) Close(closeClients CloseFn) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
|  | ||||
| 	if atomic.CompareAndSwapUint32(&l.end, 0, 1) { | ||||
| 		closeClients(l.id) | ||||
| 	} | ||||
|  | ||||
| 	if l.listen != nil { | ||||
| 		err := l.listen.Close() | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										131
									
								
								listeners/tcp_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								listeners/tcp_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,131 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewTCP(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, nil) | ||||
| 	require.Equal(t, "t1", l.id) | ||||
| 	require.Equal(t, testAddr, l.address) | ||||
| } | ||||
|  | ||||
| func TestTCPID(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, nil) | ||||
| 	require.Equal(t, "t1", l.ID()) | ||||
| } | ||||
|  | ||||
| func TestTCPAddress(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, nil) | ||||
| 	require.Equal(t, testAddr, l.Address()) | ||||
| } | ||||
|  | ||||
| func TestTCPProtocol(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, nil) | ||||
| 	require.Equal(t, "tcp", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestTCPProtocolTLS(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}) | ||||
|  | ||||
| 	l.Init(&logger) | ||||
| 	defer l.listen.Close() | ||||
| 	require.Equal(t, "tcp", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestTCPInit(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, nil) | ||||
| 	err := l.Init(&logger) | ||||
| 	l.Close(MockCloser) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	l2 := NewTCP("t2", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}) | ||||
| 	err = l2.Init(&logger) | ||||
| 	l2.Close(MockCloser) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, l2.config.TLSConfig) | ||||
| } | ||||
|  | ||||
| func TestTCPServeAndClose(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, nil) | ||||
| 	err := l.Init(&logger) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
|  | ||||
| 	require.True(t, closed) | ||||
| 	<-o | ||||
|  | ||||
| 	l.Close(MockCloser)      // coverage: close closed | ||||
| 	l.Serve(MockEstablisher) // coverage: serve closed | ||||
| } | ||||
|  | ||||
| func TestTCPServeTLSAndClose(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}) | ||||
| 	err := l.Init(&logger) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
|  | ||||
| 	require.Equal(t, true, closed) | ||||
| 	<-o | ||||
| } | ||||
|  | ||||
| func TestTCPEstablishThenEnd(t *testing.T) { | ||||
| 	l := NewTCP("t1", testAddr, nil) | ||||
| 	err := l.Init(&logger) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	established := make(chan bool) | ||||
| 	go func() { | ||||
| 		l.Serve(func(id string, c net.Conn) error { | ||||
| 			established <- true | ||||
| 			return errors.New("ending") // return an error to exit immediately | ||||
| 		}) | ||||
| 		o <- true | ||||
| 	}() | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	net.Dial("tcp", l.listen.Addr().String()) | ||||
| 	require.Equal(t, true, <-established) | ||||
| 	l.Close(MockCloser) | ||||
| 	<-o | ||||
| } | ||||
							
								
								
									
										178
									
								
								listeners/websocket.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								listeners/websocket.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,178 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// ErrInvalidMessage indicates that a message payload was not valid. | ||||
| 	ErrInvalidMessage = errors.New("message type not binary") | ||||
| ) | ||||
|  | ||||
| // Websocket is a listener for establishing websocket connections. | ||||
| type Websocket struct { // [MQTT-4.2.0-1] | ||||
| 	sync.RWMutex | ||||
| 	id        string              // the internal id of the listener | ||||
| 	address   string              // the network address to bind to | ||||
| 	config    *Config             // configuration values for the listener | ||||
| 	listen    *http.Server        // an http server for serving websocket connections | ||||
| 	log       *zerolog.Logger     // server logger | ||||
| 	establish EstablishFn         // the server's establish connection handler | ||||
| 	upgrader  *websocket.Upgrader //  upgrade the incoming http/tcp connection to a websocket compliant connection. | ||||
| 	end       uint32              // ensure the close methods are only called once | ||||
| } | ||||
|  | ||||
| // NewWebsocket initialises and returns a new Websocket listener, listening on an address. | ||||
| func NewWebsocket(id, address string, config *Config) *Websocket { | ||||
| 	if config == nil { | ||||
| 		config = new(Config) | ||||
| 	} | ||||
|  | ||||
| 	return &Websocket{ | ||||
| 		id:      id, | ||||
| 		address: address, | ||||
| 		config:  config, | ||||
| 		upgrader: &websocket.Upgrader{ | ||||
| 			Subprotocols: []string{"mqtt"}, | ||||
| 			CheckOrigin: func(r *http.Request) bool { | ||||
| 				return true | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ID returns the id of the listener. | ||||
| func (l *Websocket) ID() string { | ||||
| 	return l.id | ||||
| } | ||||
|  | ||||
| // Address returns the address of the listener. | ||||
| func (l *Websocket) Address() string { | ||||
| 	return l.address | ||||
| } | ||||
|  | ||||
| // Protocol returns the address of the listener. | ||||
| func (l *Websocket) Protocol() string { | ||||
| 	if l.config.TLSConfig != nil { | ||||
| 		return "wss" | ||||
| 	} | ||||
|  | ||||
| 	return "ws" | ||||
| } | ||||
|  | ||||
| // Init initializes the listener. | ||||
| func (l *Websocket) Init(log *zerolog.Logger) error { | ||||
| 	l.log = log | ||||
|  | ||||
| 	mux := http.NewServeMux() | ||||
| 	mux.HandleFunc("/", l.handler) | ||||
| 	l.listen = &http.Server{ | ||||
| 		Addr:         l.address, | ||||
| 		Handler:      mux, | ||||
| 		TLSConfig:    l.config.TLSConfig, | ||||
| 		ReadTimeout:  60 * time.Second, | ||||
| 		WriteTimeout: 60 * time.Second, | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // handler upgrades and handles an incoming websocket connection. | ||||
| func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) { | ||||
| 	c, err := l.upgrader.Upgrade(w, r, nil) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	defer c.Close() | ||||
|  | ||||
| 	err = l.establish(l.id, &wsConn{c.UnderlyingConn(), c}) | ||||
| 	if err != nil { | ||||
| 		l.log.Warn().Err(err).Send() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Serve starts waiting for new Websocket connections, and calls the connection | ||||
| // establishment callback for any received. | ||||
| func (l *Websocket) Serve(establish EstablishFn) { | ||||
| 	l.establish = establish | ||||
|  | ||||
| 	if l.listen.TLSConfig != nil { | ||||
| 		l.listen.ListenAndServeTLS("", "") | ||||
| 	} else { | ||||
| 		l.listen.ListenAndServe() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Close closes the listener and any client connections. | ||||
| func (l *Websocket) Close(closeClients CloseFn) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
|  | ||||
| 	if atomic.CompareAndSwapUint32(&l.end, 0, 1) { | ||||
| 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 		defer cancel() | ||||
| 		l.listen.Shutdown(ctx) | ||||
| 	} | ||||
|  | ||||
| 	closeClients(l.id) | ||||
| } | ||||
|  | ||||
| // wsConn is a websocket connection which satisfies the net.Conn interface. | ||||
| type wsConn struct { | ||||
| 	net.Conn | ||||
| 	c *websocket.Conn | ||||
| } | ||||
|  | ||||
| // Read reads the next span of bytes from the websocket connection and returns the number of bytes read. | ||||
| func (ws *wsConn) Read(p []byte) (int, error) { | ||||
| 	op, r, err := ws.c.NextReader() | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	if op != websocket.BinaryMessage { | ||||
| 		err = ErrInvalidMessage | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	var n, br int | ||||
| 	for { | ||||
| 		br, err = r.Read(p[n:]) | ||||
| 		n += br | ||||
| 		if err != nil { | ||||
| 			if err == io.EOF { | ||||
| 				err = nil | ||||
| 			} | ||||
| 			return n, err | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Write writes bytes to the websocket connection. | ||||
| func (ws *wsConn) Write(p []byte) (int, error) { | ||||
| 	err := ws.c.WriteMessage(websocket.BinaryMessage, p) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	return len(p), nil | ||||
| } | ||||
|  | ||||
| // Close signals the underlying websocket conn to close. | ||||
| func (ws *wsConn) Close() error { | ||||
| 	return ws.Conn.Close() | ||||
| } | ||||
							
								
								
									
										114
									
								
								listeners/websocket_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								listeners/websocket_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewWebsocket(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, nil) | ||||
| 	require.Equal(t, "t1", l.id) | ||||
| 	require.Equal(t, testAddr, l.address) | ||||
| } | ||||
|  | ||||
| func TestWebsocketID(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, nil) | ||||
| 	require.Equal(t, "t1", l.ID()) | ||||
| } | ||||
|  | ||||
| func TestWebsocketAddress(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, nil) | ||||
| 	require.Equal(t, testAddr, l.Address()) | ||||
| } | ||||
|  | ||||
| func TestWebsocketProtocol(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, nil) | ||||
| 	require.Equal(t, "ws", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestWebsocketProtocoTLS(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}) | ||||
| 	require.Equal(t, "wss", l.Protocol()) | ||||
| } | ||||
|  | ||||
| func TestWebsockeInit(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, nil) | ||||
| 	require.Nil(t, l.listen) | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, l.listen) | ||||
| } | ||||
|  | ||||
| func TestWebsocketServeAndClose(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, nil) | ||||
| 	l.Init(nil) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
|  | ||||
| 	require.True(t, closed) | ||||
| 	<-o | ||||
| } | ||||
|  | ||||
| func TestWebsocketServeTLSAndClose(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, &Config{ | ||||
| 		TLSConfig: tlsConfigBasic, | ||||
| 	}) | ||||
| 	err := l.Init(nil) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
|  | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
| 	require.Equal(t, true, closed) | ||||
| } | ||||
|  | ||||
| func TestWebsocketUpgrade(t *testing.T) { | ||||
| 	l := NewWebsocket("t1", testAddr, nil) | ||||
| 	l.Init(nil) | ||||
|  | ||||
| 	e := make(chan bool) | ||||
| 	l.establish = func(id string, c net.Conn) error { | ||||
| 		e <- true | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	s := httptest.NewServer(http.HandlerFunc(l.handler)) | ||||
| 	ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, true, <-e) | ||||
|  | ||||
| 	s.Close() | ||||
| 	ws.Close() | ||||
| } | ||||
| @@ -1,12 +1,18 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
| 
 | ||||
| package packets | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/binary" | ||||
| 	"io" | ||||
| 	"unicode/utf8" | ||||
| 	"unsafe" | ||||
| ) | ||||
| 
 | ||||
| // bytesToString provides a zero-alloc, no-copy byte to string conversion. | ||||
| // bytesToString provides a zero-alloc no-copy byte to string conversion. | ||||
| // via https://github.com/golang/go/issues/25484#issuecomment-391415660 | ||||
| func bytesToString(bs []byte) string { | ||||
| 	return *(*string)(unsafe.Pointer(&bs)) | ||||
| @@ -15,12 +21,21 @@ func bytesToString(bs []byte) string { | ||||
| // decodeUint16 extracts the value of two bytes from a byte array. | ||||
| func decodeUint16(buf []byte, offset int) (uint16, int, error) { | ||||
| 	if len(buf) < offset+2 { | ||||
| 		return 0, 0, ErrOffsetUintOutOfRange | ||||
| 		return 0, 0, ErrMalformedOffsetUintOutOfRange | ||||
| 	} | ||||
| 
 | ||||
| 	return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil | ||||
| } | ||||
| 
 | ||||
| // decodeUint32 extracts the value of four bytes from a byte array. | ||||
| func decodeUint32(buf []byte, offset int) (uint32, int, error) { | ||||
| 	if len(buf) < offset+4 { | ||||
| 		return 0, 0, ErrMalformedOffsetUintOutOfRange | ||||
| 	} | ||||
| 
 | ||||
| 	return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil | ||||
| } | ||||
| 
 | ||||
| // decodeString extracts a string from a byte array, beginning at an offset. | ||||
| func decodeString(buf []byte, offset int) (string, int, error) { | ||||
| 	b, n, err := decodeBytes(buf, offset) | ||||
| @@ -28,22 +43,27 @@ func decodeString(buf []byte, offset int) (string, int, error) { | ||||
| 		return "", 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5] | ||||
| 		return "", 0, ErrMalformedInvalidUTF8 | ||||
| 	} | ||||
| 
 | ||||
| 	return bytesToString(b), n, nil | ||||
| } | ||||
| 
 | ||||
| // validUTF8 checks if the byte array contains valid UTF-8 characters. | ||||
| func validUTF8(b []byte) bool { | ||||
| 	return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2] | ||||
| } | ||||
| 
 | ||||
| // decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads. | ||||
| func decodeBytes(buf []byte, offset int) ([]byte, int, error) { | ||||
| 	length, next, err := decodeUint16(buf, offset) | ||||
| 	if err != nil { | ||||
| 		return make([]byte, 0, 0), 0, err | ||||
| 		return make([]byte, 0), 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	if next+int(length) > len(buf) { | ||||
| 		return make([]byte, 0, 0), 0, ErrOffsetStrOutOfRange | ||||
| 	} | ||||
| 
 | ||||
| 	if !validUTF8(buf[next : next+int(length)]) { | ||||
| 		return make([]byte, 0, 0), 0, ErrOffsetStrInvalidUTF8 | ||||
| 		return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange | ||||
| 	} | ||||
| 
 | ||||
| 	return buf[next : next+int(length)], next + int(length), nil | ||||
| @@ -52,7 +72,7 @@ func decodeBytes(buf []byte, offset int) ([]byte, int, error) { | ||||
| // decodeByte extracts the value of a byte from a byte array. | ||||
| func decodeByte(buf []byte, offset int) (byte, int, error) { | ||||
| 	if len(buf) <= offset { | ||||
| 		return 0, 0, ErrOffsetByteOutOfRange | ||||
| 		return 0, 0, ErrMalformedOffsetByteOutOfRange | ||||
| 	} | ||||
| 	return buf[offset], offset + 1, nil | ||||
| } | ||||
| @@ -60,7 +80,7 @@ func decodeByte(buf []byte, offset int) (byte, int, error) { | ||||
| // decodeByteBool extracts the value of a byte from a byte array and returns a bool. | ||||
| func decodeByteBool(buf []byte, offset int) (bool, int, error) { | ||||
| 	if len(buf) <= offset { | ||||
| 		return false, 0, ErrOffsetBoolOutOfRange | ||||
| 		return false, 0, ErrMalformedOffsetBoolOutOfRange | ||||
| 	} | ||||
| 	return 1&buf[offset] > 0, offset + 1, nil | ||||
| } | ||||
| @@ -75,7 +95,7 @@ func encodeBool(b bool) byte { | ||||
| 
 | ||||
| // encodeBytes encodes a byte array to a byte array. Used primarily for message payloads. | ||||
| func encodeBytes(val []byte) []byte { | ||||
| 	// In many circumstances the number of bytes being encoded is small. | ||||
| 	// In most circumstances the number of bytes being encoded is small. | ||||
| 	// Setting the cap to a low amount allows us to account for those without | ||||
| 	// triggering allocation growth on append unless we need to. | ||||
| 	buf := make([]byte, 2, 32) | ||||
| @@ -90,6 +110,13 @@ func encodeUint16(val uint16) []byte { | ||||
| 	return buf | ||||
| } | ||||
| 
 | ||||
| // encodeUint32 encodes a uint16 value to a byte array. | ||||
| func encodeUint32(val uint32) []byte { | ||||
| 	buf := make([]byte, 4) | ||||
| 	binary.BigEndian.PutUint32(buf, val) | ||||
| 	return buf | ||||
| } | ||||
| 
 | ||||
| // encodeString encodes a string to a byte array. | ||||
| func encodeString(val string) []byte { | ||||
| 	// Like encodeBytes, we set the cap to a small number to avoid | ||||
| @@ -99,16 +126,47 @@ func encodeString(val string) []byte { | ||||
| 	return append(buf, []byte(val)...) | ||||
| } | ||||
| 
 | ||||
| // validUTF8 checks if the byte array contains valid UTF-8 characters, specifically | ||||
| // conforming to the MQTT specification requirements. | ||||
| func validUTF8(b []byte) bool { | ||||
| 	// [MQTT-1.4.0-1] The character data in a UTF-8 encoded string MUST be well-formed UTF-8... | ||||
| 	if !utf8.Valid(b) { | ||||
| 		return false | ||||
| // encodeLength writes length bits for the header. | ||||
| func encodeLength(b *bytes.Buffer, length int64) { | ||||
| 	// 1.5.5 Variable Byte Integer encode non-normative | ||||
| 	// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027 | ||||
| 	for { | ||||
| 		eb := byte(length % 128) | ||||
| 		length /= 128 | ||||
| 		if length > 0 { | ||||
| 			eb |= 0x80 | ||||
| 		} | ||||
| 		b.WriteByte(eb) | ||||
| 		if length == 0 { | ||||
| 			break // [MQTT-1.5.5-1] | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func DecodeLength(b io.ByteReader) (n, bu int, err error) { | ||||
| 	// see 1.5.5 Variable Byte Integer decode non-normative | ||||
| 	// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027 | ||||
| 	var multiplier uint32 | ||||
| 	var value uint32 | ||||
| 	bu = 1 | ||||
| 	for { | ||||
| 		eb, err := b.ReadByte() | ||||
| 		if err != nil { | ||||
| 			return 0, bu, err | ||||
| 		} | ||||
| 
 | ||||
| 		value |= uint32(eb&127) << multiplier | ||||
| 		if value > 268435455 { | ||||
| 			return 0, bu, ErrMalformedVariableByteInteger | ||||
| 		} | ||||
| 
 | ||||
| 		if (eb & 128) == 0 { | ||||
| 			break | ||||
| 		} | ||||
| 
 | ||||
| 		multiplier += 7 | ||||
| 		bu++ | ||||
| 	} | ||||
| 
 | ||||
| 	// [MQTT-1.4.0-2] A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000... | ||||
| 	// ... | ||||
| 	return true | ||||
| 
 | ||||
| 	return int(value), bu, nil | ||||
| } | ||||
							
								
								
									
										422
									
								
								packets/codec_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										422
									
								
								packets/codec_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,422 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestBytesToString(t *testing.T) { | ||||
| 	b := []byte{'a', 'b', 'c'} | ||||
| 	require.Equal(t, "abc", bytesToString(b)) | ||||
| } | ||||
|  | ||||
| func TestDecodeString(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		name       string | ||||
| 		rawBytes   []byte | ||||
| 		result     string | ||||
| 		offset     int | ||||
| 		shouldFail error | ||||
| 	}{ | ||||
| 		{ | ||||
| 			offset:   0, | ||||
| 			rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, | ||||
| 			result:   "a/b/c/d", | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset: 14, | ||||
| 			rawBytes: []byte{ | ||||
| 				Connect << 4, 17, // Fixed header | ||||
| 				0, 6, // Protocol Name - MSB+LSB | ||||
| 				'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name | ||||
| 				3,     // Protocol Version | ||||
| 				0,     // Packet Flags | ||||
| 				0, 30, // Keepalive | ||||
| 				0, 3, // Client ID - MSB+LSB | ||||
| 				'h', 'e', 'y', // Client ID "zen"}, | ||||
| 			}, | ||||
| 			result: "hey", | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:   2, | ||||
| 			rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97}, | ||||
| 			result:   "1/2/3/4/a/b/c/d/e/^/@/!", | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:   0, | ||||
| 			rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38}, | ||||
| 			result:   "x/y/z", | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:     0, | ||||
| 			rawBytes:   []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'}, | ||||
| 			shouldFail: ErrMalformedOffsetBytesOutOfRange, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:     5, | ||||
| 			rawBytes:   []byte{0, 7, 97, 47, 98, 47, 'x'}, | ||||
| 			shouldFail: ErrMalformedOffsetBytesOutOfRange, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:     9, | ||||
| 			rawBytes:   []byte{0, 7, 97, 47, 98, 47, 'y'}, | ||||
| 			shouldFail: ErrMalformedOffsetUintOutOfRange, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset: 17, | ||||
| 			rawBytes: []byte{ | ||||
| 				Connect << 4, 0, // Fixed header | ||||
| 				0, 4, // Protocol Name - MSB+LSB | ||||
| 				'M', 'Q', 'T', 'T', // Protocol Name | ||||
| 				4,     // Protocol Version | ||||
| 				0,     // Flags | ||||
| 				0, 20, // Keepalive | ||||
| 				0, 3, // Client ID - MSB+LSB | ||||
| 				'z', 'e', 'n', // Client ID "zen" | ||||
| 				0, 6, // Will Topic - MSB+LSB | ||||
| 				'l', | ||||
| 			}, | ||||
| 			shouldFail: ErrMalformedOffsetBytesOutOfRange, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:     0, | ||||
| 			rawBytes:   []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100}, | ||||
| 			shouldFail: ErrMalformedInvalidUTF8, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		t.Run(fmt.Sprint(i), func(t *testing.T) { | ||||
| 			result, _, err := decodeString(wanted.rawBytes, wanted.offset) | ||||
| 			if wanted.shouldFail != nil { | ||||
| 				require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			require.NoError(t, err) | ||||
| 			require.Equal(t, wanted.result, result) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3] | ||||
| 	result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "\ufeff", result) | ||||
| } | ||||
|  | ||||
| func TestDecodeBytes(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     []uint8 | ||||
| 		next       int | ||||
| 		offset     int | ||||
| 		shouldFail error | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session) | ||||
| 			result:   []byte{0x4d, 0x51, 0x54, 0x54}, | ||||
| 			next:     6, | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start | ||||
| 			result:   []byte{0x4d, 0x51, 0x54, 0x54}, | ||||
| 			next:     6, | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 4, 77, 81}, | ||||
| 			offset:     0, | ||||
| 			shouldFail: ErrMalformedOffsetBytesOutOfRange, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 4, 77, 81}, | ||||
| 			offset:     8, | ||||
| 			shouldFail: ErrMalformedOffsetUintOutOfRange, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		t.Run(fmt.Sprint(i), func(t *testing.T) { | ||||
| 			result, _, err := decodeBytes(wanted.rawBytes, wanted.offset) | ||||
| 			if wanted.shouldFail != nil { | ||||
| 				require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			require.NoError(t, err) | ||||
| 			require.Equal(t, wanted.result, result) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeByte(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     uint8 | ||||
| 		offset     int | ||||
| 		shouldFail error | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes | ||||
| 			result:   uint8(0x00), | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84}, | ||||
| 			result:   uint8(0x04), | ||||
| 			offset:   1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84}, | ||||
| 			result:   uint8(0x4d), | ||||
| 			offset:   2, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84}, | ||||
| 			result:   uint8(0x51), | ||||
| 			offset:   3, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 4, 77, 80, 82, 84}, | ||||
| 			offset:     8, | ||||
| 			shouldFail: ErrMalformedOffsetByteOutOfRange, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		t.Run(fmt.Sprint(i), func(t *testing.T) { | ||||
| 			result, offset, err := decodeByte(wanted.rawBytes, wanted.offset) | ||||
| 			if wanted.shouldFail != nil { | ||||
| 				require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			require.NoError(t, err) | ||||
| 			require.Equal(t, wanted.result, result) | ||||
| 			require.Equal(t, i+1, offset) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeUint16(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     uint16 | ||||
| 		offset     int | ||||
| 		shouldFail error | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, | ||||
| 			result:   uint16(0x07), | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, | ||||
| 			result:   uint16(0x761), | ||||
| 			offset:   1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 7, 255, 47}, | ||||
| 			offset:     8, | ||||
| 			shouldFail: ErrMalformedOffsetUintOutOfRange, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		t.Run(fmt.Sprint(i), func(t *testing.T) { | ||||
| 			result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset) | ||||
| 			if wanted.shouldFail != nil { | ||||
| 				require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			require.NoError(t, err) | ||||
| 			require.Equal(t, wanted.result, result) | ||||
| 			require.Equal(t, i+2, offset) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeUint32(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     uint32 | ||||
| 		offset     int | ||||
| 		shouldFail error | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 0, 0, 7, 8}, | ||||
| 			result:   uint32(7), | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 0, 1, 226, 64, 8}, | ||||
| 			result:   uint32(123456), | ||||
| 			offset:   1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 7, 255, 47}, | ||||
| 			offset:     8, | ||||
| 			shouldFail: ErrMalformedOffsetUintOutOfRange, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		t.Run(fmt.Sprint(i), func(t *testing.T) { | ||||
| 			result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset) | ||||
| 			if wanted.shouldFail != nil { | ||||
| 				require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			require.NoError(t, err) | ||||
| 			require.Equal(t, wanted.result, result) | ||||
| 			require.Equal(t, i+4, offset) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeByteBool(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     bool | ||||
| 		offset     int | ||||
| 		shouldFail error | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0x00, 0x00}, | ||||
| 			result:   false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0x01, 0x00}, | ||||
| 			result:   true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0x01, 0x00}, | ||||
| 			offset:     5, | ||||
| 			shouldFail: ErrMalformedOffsetBoolOutOfRange, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		t.Run(fmt.Sprint(i), func(t *testing.T) { | ||||
| 			result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset) | ||||
| 			if wanted.shouldFail != nil { | ||||
| 				require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			require.NoError(t, err) | ||||
| 			require.Equal(t, wanted.result, result) | ||||
| 			require.Equal(t, 1, offset) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeLength(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{0x78}) | ||||
| 	n, bu, err := DecodeLength(b) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, 120, n) | ||||
| 	require.Equal(t, 1, bu) | ||||
|  | ||||
| 	b = bytes.NewBuffer([]byte{255, 255, 255, 127}) | ||||
| 	n, bu, err = DecodeLength(b) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, 268435455, n) | ||||
| 	require.Equal(t, 4, bu) | ||||
| } | ||||
|  | ||||
| func TestDecodeLengthErrors(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	_, _, err := DecodeLength(b) | ||||
| 	require.Error(t, err) | ||||
|  | ||||
| 	b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}) | ||||
| 	_, _, err = DecodeLength(b) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, ErrMalformedVariableByteInteger) | ||||
| } | ||||
|  | ||||
| func TestEncodeBool(t *testing.T) { | ||||
| 	result := encodeBool(true) | ||||
| 	require.Equal(t, byte(1), result) | ||||
|  | ||||
| 	result = encodeBool(false) | ||||
| 	require.Equal(t, byte(0), result) | ||||
|  | ||||
| 	// Check failure. | ||||
| 	result = encodeBool(false) | ||||
| 	require.NotEqual(t, byte(1), result) | ||||
| } | ||||
|  | ||||
| func TestEncodeBytes(t *testing.T) { | ||||
| 	result := encodeBytes([]byte("testing")) | ||||
| 	require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result) | ||||
|  | ||||
| 	result = encodeBytes([]byte("testing")) | ||||
| 	require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result) | ||||
| } | ||||
|  | ||||
| func TestEncodeUint16(t *testing.T) { | ||||
| 	result := encodeUint16(0) | ||||
| 	require.Equal(t, []byte{0x00, 0x00}, result) | ||||
|  | ||||
| 	result = encodeUint16(32767) | ||||
| 	require.Equal(t, []byte{0x7f, 0xff}, result) | ||||
|  | ||||
| 	result = encodeUint16(65535) | ||||
| 	require.Equal(t, []byte{0xff, 0xff}, result) | ||||
| } | ||||
|  | ||||
| func TestEncodeUint32(t *testing.T) { | ||||
| 	result := encodeUint32(7) | ||||
| 	require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result) | ||||
|  | ||||
| 	result = encodeUint32(32767) | ||||
| 	require.Equal(t, []byte{0, 0, 127, 255}, result) | ||||
|  | ||||
| 	result = encodeUint32(math.MaxUint32) | ||||
| 	require.Equal(t, []byte{255, 255, 255, 255}, result) | ||||
| } | ||||
|  | ||||
| func TestEncodeString(t *testing.T) { | ||||
| 	result := encodeString("testing") | ||||
| 	require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result) | ||||
|  | ||||
| 	result = encodeString("") | ||||
| 	require.Equal(t, []uint8{0x00, 0x00}, result) | ||||
|  | ||||
| 	result = encodeString("a") | ||||
| 	require.Equal(t, []uint8{0x00, 0x01, 0x61}, result) | ||||
|  | ||||
| 	result = encodeString("b") | ||||
| 	require.NotEqual(t, []uint8{0x00, 0x00}, result) | ||||
| } | ||||
|  | ||||
| func TestEncodeLength(t *testing.T) { | ||||
| 	b := new(bytes.Buffer) | ||||
| 	encodeLength(b, 120) | ||||
| 	require.Equal(t, []byte{0x78}, b.Bytes()) | ||||
|  | ||||
| 	b = new(bytes.Buffer) | ||||
| 	encodeLength(b, math.MaxInt64) | ||||
| 	require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes()) | ||||
| } | ||||
|  | ||||
| func TestValidUTF8(t *testing.T) { | ||||
| 	require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67})) | ||||
| 	require.False(t, validUTF8([]byte{0xff, 0xff})) | ||||
| 	require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74})) | ||||
| } | ||||
							
								
								
									
										127
									
								
								packets/codes.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								packets/codes.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,127 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| // Code contains a reason code and reason string for a response. | ||||
| type Code struct { | ||||
| 	Reason string | ||||
| 	Code   byte | ||||
| } | ||||
|  | ||||
| // String returns the readable reason for a code. | ||||
| func (c Code) String() string { | ||||
| 	return c.Reason | ||||
| } | ||||
|  | ||||
| // Error returns the readable reason for a code. | ||||
| func (c Code) Error() string { | ||||
| 	return c.Reason | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	// QosCodes indicicates the reason codes for each Qos byte. | ||||
| 	QosCodes = map[byte]Code{ | ||||
| 		0: CodeGrantedQos0, | ||||
| 		1: CodeGrantedQos1, | ||||
| 		2: CodeGrantedQos2, | ||||
| 	} | ||||
|  | ||||
| 	CodeSuccess                               = Code{Code: 0x00, Reason: "success"} | ||||
| 	CodeDisconnect                            = Code{Code: 0x00, Reason: "disconnected"} | ||||
| 	CodeGrantedQos0                           = Code{Code: 0x00, Reason: "granted qos 0"} | ||||
| 	CodeGrantedQos1                           = Code{Code: 0x01, Reason: "granted qos 1"} | ||||
| 	CodeGrantedQos2                           = Code{Code: 0x02, Reason: "granted qos 2"} | ||||
| 	CodeDisconnectWillMessage                 = Code{Code: 0x04, Reason: "disconnect with will message"} | ||||
| 	CodeNoMatchingSubscribers                 = Code{Code: 0x10, Reason: "no matching subscribers"} | ||||
| 	CodeNoSubscriptionExisted                 = Code{Code: 0x11, Reason: "no subscription existed"} | ||||
| 	CodeContinueAuthentication                = Code{Code: 0x18, Reason: "continue authentication"} | ||||
| 	CodeReAuthenticate                        = Code{Code: 0x19, Reason: "re-authenticate"} | ||||
| 	ErrUnspecifiedError                       = Code{Code: 0x80, Reason: "unspecified error"} | ||||
| 	ErrMalformedPacket                        = Code{Code: 0x81, Reason: "malformed packet"} | ||||
| 	ErrMalformedProtocolName                  = Code{Code: 0x81, Reason: "malformed packet: protocol name"} | ||||
| 	ErrMalformedProtocolVersion               = Code{Code: 0x81, Reason: "malformed packet: protocol version"} | ||||
| 	ErrMalformedFlags                         = Code{Code: 0x81, Reason: "malformed packet: flags"} | ||||
| 	ErrMalformedKeepalive                     = Code{Code: 0x81, Reason: "malformed packet: keepalive"} | ||||
| 	ErrMalformedPacketID                      = Code{Code: 0x81, Reason: "malformed packet: packet identifier"} | ||||
| 	ErrMalformedTopic                         = Code{Code: 0x81, Reason: "malformed packet: topic"} | ||||
| 	ErrMalformedWillTopic                     = Code{Code: 0x81, Reason: "malformed packet: will topic"} | ||||
| 	ErrMalformedWillPayload                   = Code{Code: 0x81, Reason: "malformed packet: will message"} | ||||
| 	ErrMalformedUsername                      = Code{Code: 0x81, Reason: "malformed packet: username"} | ||||
| 	ErrMalformedPassword                      = Code{Code: 0x81, Reason: "malformed packet: password"} | ||||
| 	ErrMalformedQos                           = Code{Code: 0x81, Reason: "malformed packet: qos"} | ||||
| 	ErrMalformedOffsetUintOutOfRange          = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"} | ||||
| 	ErrMalformedOffsetBytesOutOfRange         = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"} | ||||
| 	ErrMalformedOffsetByteOutOfRange          = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"} | ||||
| 	ErrMalformedOffsetBoolOutOfRange          = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"} | ||||
| 	ErrMalformedInvalidUTF8                   = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"} | ||||
| 	ErrMalformedVariableByteInteger           = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"} | ||||
| 	ErrMalformedBadProperty                   = Code{Code: 0x81, Reason: "malformed packet: unknown property"} | ||||
| 	ErrMalformedProperties                    = Code{Code: 0x81, Reason: "malformed packet: properties"} | ||||
| 	ErrMalformedWillProperties                = Code{Code: 0x81, Reason: "malformed packet: will properties"} | ||||
| 	ErrMalformedSessionPresent                = Code{Code: 0x81, Reason: "malformed packet: session present"} | ||||
| 	ErrMalformedReasonCode                    = Code{Code: 0x81, Reason: "malformed packet: reason code"} | ||||
| 	ErrProtocolViolation                      = Code{Code: 0x82, Reason: "protocol violation"} | ||||
| 	ErrProtocolViolationProtocolName          = Code{Code: 0x82, Reason: "protocol violation: protocol name"} | ||||
| 	ErrProtocolViolationProtocolVersion       = Code{Code: 0x82, Reason: "protocol violation: protocol version"} | ||||
| 	ErrProtocolViolationReservedBit           = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"} | ||||
| 	ErrProtocolViolationFlagNoUsername        = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"} | ||||
| 	ErrProtocolViolationFlagNoPassword        = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"} | ||||
| 	ErrProtocolViolationUsernameNoFlag        = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"} | ||||
| 	ErrProtocolViolationPasswordNoFlag        = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"} | ||||
| 	ErrProtocolViolationPasswordTooLong       = Code{Code: 0x82, Reason: "protocol violation: password too long"} | ||||
| 	ErrProtocolViolationUsernameTooLong       = Code{Code: 0x82, Reason: "protocol violation: username too long"} | ||||
| 	ErrProtocolViolationNoPacketID            = Code{Code: 0x82, Reason: "protocol violation: missing packet id"} | ||||
| 	ErrProtocolViolationSurplusPacketID       = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"} | ||||
| 	ErrProtocolViolationQosOutOfRange         = Code{Code: 0x82, Reason: "protocol violation: qos out of range"} | ||||
| 	ErrProtocolViolationSecondConnect         = Code{Code: 0x82, Reason: "protocol violation: second connect packet"} | ||||
| 	ErrProtocolViolationZeroNonZeroExpiry     = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"} | ||||
| 	ErrProtocolViolationRequireFirstConnect   = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"} | ||||
| 	ErrProtocolViolationWillFlagNoPayload     = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"} | ||||
| 	ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"} | ||||
| 	ErrProtocolViolationSurplusWildcard       = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"} | ||||
| 	ErrProtocolViolationSurplusSubID          = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"} | ||||
| 	ErrProtocolViolationInvalidTopic          = Code{Code: 0x82, Reason: "protocol violation: invalid topic"} | ||||
| 	ErrProtocolViolationInvalidSharedNoLocal  = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"} | ||||
| 	ErrProtocolViolationNoFilters             = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"} | ||||
| 	ErrProtocolViolationInvalidReason         = Code{Code: 0x82, Reason: "protocol violation: invalid reason"} | ||||
| 	ErrProtocolViolationOversizeSubID         = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"} | ||||
| 	ErrProtocolViolationDupNoQos              = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"} | ||||
| 	ErrProtocolViolationUnsupportedProperty   = Code{Code: 0x82, Reason: "protocol violation: unsupported property"} | ||||
| 	ErrProtocolViolationNoTopic               = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"} | ||||
| 	ErrImplementationSpecificError            = Code{Code: 0x83, Reason: "implementation specific error"} | ||||
| 	ErrRejectPacket                           = Code{Code: 0x83, Reason: "packet rejected"} | ||||
| 	ErrUnsupportedProtocolVersion             = Code{Code: 0x84, Reason: "unsupported protocol version"} | ||||
| 	ErrClientIdentifierNotValid               = Code{Code: 0x85, Reason: "client identifier not valid"} | ||||
| 	ErrClientIdentifierTooLong                = Code{Code: 0x85, Reason: "client identifier too long"} | ||||
| 	ErrBadUsernameOrPassword                  = Code{Code: 0x86, Reason: "bad username or password"} | ||||
| 	ErrNotAuthorized                          = Code{Code: 0x87, Reason: "not authorized"} | ||||
| 	ErrServerUnavailable                      = Code{Code: 0x88, Reason: "server unavailable"} | ||||
| 	ErrServerBusy                             = Code{Code: 0x89, Reason: "server busy"} | ||||
| 	ErrBanned                                 = Code{Code: 0x8A, Reason: "banned"} | ||||
| 	ErrServerShuttingDown                     = Code{Code: 0x8B, Reason: "server shutting down"} | ||||
| 	ErrBadAuthenticationMethod                = Code{Code: 0x8C, Reason: "bad authentication method"} | ||||
| 	ErrKeepAliveTimeout                       = Code{Code: 0x8D, Reason: "keep alive timeout"} | ||||
| 	ErrSessionTakenOver                       = Code{Code: 0x8E, Reason: "session takeover"} | ||||
| 	ErrTopicFilterInvalid                     = Code{Code: 0x8F, Reason: "topic filter invalid"} | ||||
| 	ErrTopicNameInvalid                       = Code{Code: 0x90, Reason: "topic name invalid"} | ||||
| 	ErrPacketIdentifierInUse                  = Code{Code: 0x91, Reason: "packet identifier in use"} | ||||
| 	ErrPacketIdentifierNotFound               = Code{Code: 0x92, Reason: "packet identifier not found"} | ||||
| 	ErrReceiveMaximum                         = Code{Code: 0x93, Reason: "receive maximum exceeded"} | ||||
| 	ErrTopicAliasInvalid                      = Code{Code: 0x94, Reason: "topic alias invalid"} | ||||
| 	ErrPacketTooLarge                         = Code{Code: 0x95, Reason: "packet too large"} | ||||
| 	ErrMessageRateTooHigh                     = Code{Code: 0x96, Reason: "message rate too high"} | ||||
| 	ErrQuotaExceeded                          = Code{Code: 0x97, Reason: "quota exceeded"} | ||||
| 	ErrAdministrativeAction                   = Code{Code: 0x98, Reason: "administrative action"} | ||||
| 	ErrPayloadFormatInvalid                   = Code{Code: 0x99, Reason: "payload format invalid"} | ||||
| 	ErrRetainNotSupported                     = Code{Code: 0x9A, Reason: "retain not supported"} | ||||
| 	ErrQosNotSupported                        = Code{Code: 0x9B, Reason: "qos not supported"} | ||||
| 	ErrUseAnotherServer                       = Code{Code: 0x9C, Reason: "use another server"} | ||||
| 	ErrServerMoved                            = Code{Code: 0x9D, Reason: "server moved"} | ||||
| 	ErrSharedSubscriptionsNotSupported        = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"} | ||||
| 	ErrConnectionRateExceeded                 = Code{Code: 0x9F, Reason: "connection rate exceeded"} | ||||
| 	ErrMaxConnectTime                         = Code{Code: 0xA0, Reason: "maximum connect time"} | ||||
| 	ErrSubscriptionIdentifiersNotSupported    = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} | ||||
| 	ErrWildcardSubscriptionsNotSupported      = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} | ||||
| ) | ||||
							
								
								
									
										29
									
								
								packets/codes_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								packets/codes_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestCodesString(t *testing.T) { | ||||
| 	c := Code{ | ||||
| 		Reason: "test", | ||||
| 		Code:   0x1, | ||||
| 	} | ||||
|  | ||||
| 	require.Equal(t, "test", c.String()) | ||||
| } | ||||
|  | ||||
| func TestCodesErrorr(t *testing.T) { | ||||
| 	c := Code{ | ||||
| 		Reason: "error", | ||||
| 		Code:   0x1, | ||||
| 	} | ||||
|  | ||||
| 	require.Equal(t, "error", error(c).Error()) | ||||
| } | ||||
							
								
								
									
										63
									
								
								packets/fixedheader.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								packets/fixedheader.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| ) | ||||
|  | ||||
| // FixedHeader contains the values of the fixed header portion of the MQTT packet. | ||||
| type FixedHeader struct { | ||||
| 	Remaining int  `json:"remaining"` // the number of remaining bytes in the payload. | ||||
| 	Type      byte `json:"type"`      // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1). | ||||
| 	Qos       byte `json:"qos"`       // indicates the quality of service expected. | ||||
| 	Dup       bool `json:"dup"`       // indicates if the packet was already sent at an earlier time. | ||||
| 	Retain    bool `json:"retain"`    // whether the message should be retained. | ||||
| } | ||||
|  | ||||
| // Encode encodes the FixedHeader and returns a bytes buffer. | ||||
| func (fh *FixedHeader) Encode(buf *bytes.Buffer) { | ||||
| 	buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain)) | ||||
| 	encodeLength(buf, int64(fh.Remaining)) | ||||
| } | ||||
|  | ||||
| // Decode extracts the specification bits from the header byte. | ||||
| func (fh *FixedHeader) Decode(hb byte) error { | ||||
| 	fh.Type = hb >> 4 // Get the message type from the first 4 bytes. | ||||
|  | ||||
| 	switch fh.Type { | ||||
| 	case Publish: | ||||
| 		if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 { | ||||
| 			return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4] | ||||
| 		} | ||||
|  | ||||
| 		fh.Dup = (hb>>3)&0x01 > 0 // is duplicate | ||||
| 		fh.Qos = (hb >> 1) & 0x03 // qos flag | ||||
| 		fh.Retain = hb&0x01 > 0   // is retain flag | ||||
| 	case Pubrel: | ||||
| 		fallthrough | ||||
| 	case Subscribe: | ||||
| 		fallthrough | ||||
| 	case Unsubscribe: | ||||
| 		if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1] | ||||
| 			return ErrMalformedFlags | ||||
| 		} | ||||
|  | ||||
| 		fh.Qos = (hb >> 1) & 0x03 | ||||
| 	default: | ||||
| 		if (hb>>0)&0x01 != 0 || | ||||
| 			(hb>>1)&0x01 != 0 || | ||||
| 			(hb>>2)&0x01 != 0 || | ||||
| 			(hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1] | ||||
| 			return ErrMalformedFlags | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if fh.Qos == 0 && fh.Dup { | ||||
| 		return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2] | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										237
									
								
								packets/fixedheader_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										237
									
								
								packets/fixedheader_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,237 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| type fixedHeaderTable struct { | ||||
| 	desc        string | ||||
| 	rawBytes    []byte | ||||
| 	header      FixedHeader | ||||
| 	packetError bool | ||||
| 	expect      error | ||||
| } | ||||
|  | ||||
| var fixedHeaderExpected = []fixedHeaderTable{ | ||||
| 	{ | ||||
| 		desc:     "connect", | ||||
| 		rawBytes: []byte{Connect << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "connack", | ||||
| 		rawBytes: []byte{Connack << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish", | ||||
| 		rawBytes: []byte{Publish << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish qos 1", | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish qos 1 retain", | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish qos 2", | ||||
| 		rawBytes: []byte{Publish<<4 | 2<<1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish qos 2 retain", | ||||
| 		rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish dup qos 0", | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<3, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 		expect:   ErrProtocolViolationDupNoQos, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish dup qos 0 retain", | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0}, | ||||
| 		expect:   ErrProtocolViolationDupNoQos, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish dup qos 1 retain", | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "publish dup qos 2 retain", | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "puback", | ||||
| 		rawBytes: []byte{Puback << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "pubrec", | ||||
| 		rawBytes: []byte{Pubrec << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "pubrel", | ||||
| 		rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "pubcomp", | ||||
| 		rawBytes: []byte{Pubcomp << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "subscribe", | ||||
| 		rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "suback", | ||||
| 		rawBytes: []byte{Suback << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "unsubscribe", | ||||
| 		rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "unsuback", | ||||
| 		rawBytes: []byte{Unsuback << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "pingreq", | ||||
| 		rawBytes: []byte{Pingreq << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "pingresp", | ||||
| 		rawBytes: []byte{Pingresp << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "disconnect", | ||||
| 		rawBytes: []byte{Disconnect << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "auth", | ||||
| 		rawBytes: []byte{Auth << 4, 0x00}, | ||||
| 		header:   FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 	}, | ||||
|  | ||||
| 	// remaining length | ||||
| 	{ | ||||
| 		desc:     "remaining length 10", | ||||
| 		rawBytes: []byte{Publish << 4, 0x0a}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "remaining length 512", | ||||
| 		rawBytes: []byte{Publish << 4, 0x80, 0x04}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "remaining length 978", | ||||
| 		rawBytes: []byte{Publish << 4, 0xd2, 0x07}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "remaining length 20202", | ||||
| 		rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01}, | ||||
| 		header:   FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:        "remaining length oversize", | ||||
| 		rawBytes:    []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, | ||||
| 		header:      FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333}, | ||||
| 		packetError: true, | ||||
| 	}, | ||||
|  | ||||
| 	// Invalid flags for packet | ||||
| 	{ | ||||
| 		desc:     "invalid type dup is true", | ||||
| 		rawBytes: []byte{Connect<<4 | 1<<3, 0x00}, | ||||
| 		header:   FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0}, | ||||
| 		expect:   ErrMalformedFlags, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "invalid type qos is 1", | ||||
| 		rawBytes: []byte{Connect<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0}, | ||||
| 		expect:   ErrMalformedFlags, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "invalid type retain is true", | ||||
| 		rawBytes: []byte{Connect<<4 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0}, | ||||
| 		expect:   ErrMalformedFlags, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "invalid publish qos bits 1 + 2 set", | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00}, | ||||
| 		header:   FixedHeader{Type: Publish}, | ||||
| 		expect:   ErrProtocolViolationQosOutOfRange, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "invalid pubrel bits 3,2,1,0 should be 0,0,1,0", | ||||
| 		rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00}, | ||||
| 		header:   FixedHeader{Type: Pubrel, Qos: 1}, | ||||
| 		expect:   ErrMalformedFlags, | ||||
| 	}, | ||||
| 	{ | ||||
| 		desc:     "invalid subscribe bits 3,2,1,0 should be 0,0,1,0", | ||||
| 		rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00}, | ||||
| 		header:   FixedHeader{Type: Subscribe, Qos: 1}, | ||||
| 		expect:   ErrMalformedFlags, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| func TestFixedHeaderEncode(t *testing.T) { | ||||
| 	for _, wanted := range fixedHeaderExpected { | ||||
| 		t.Run(wanted.desc, func(t *testing.T) { | ||||
| 			buf := new(bytes.Buffer) | ||||
| 			wanted.header.Encode(buf) | ||||
| 			if wanted.expect == nil { | ||||
| 				require.Equal(t, len(wanted.rawBytes), len(buf.Bytes())) | ||||
| 				require.EqualValues(t, wanted.rawBytes, buf.Bytes()) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestFixedHeaderDecode(t *testing.T) { | ||||
| 	for _, wanted := range fixedHeaderExpected { | ||||
| 		t.Run(wanted.desc, func(t *testing.T) { | ||||
| 			fh := new(FixedHeader) | ||||
| 			err := fh.Decode(wanted.rawBytes[0]) | ||||
| 			if wanted.expect != nil { | ||||
| 				require.Equal(t, wanted.expect, err) | ||||
| 			} else { | ||||
| 				require.NoError(t, err) | ||||
| 				require.Equal(t, wanted.header.Type, fh.Type) | ||||
| 				require.Equal(t, wanted.header.Dup, fh.Dup) | ||||
| 				require.Equal(t, wanted.header.Qos, fh.Qos) | ||||
| 				require.Equal(t, wanted.header.Retain, fh.Retain) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										1141
									
								
								packets/packets.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1141
									
								
								packets/packets.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										502
									
								
								packets/packets_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										502
									
								
								packets/packets_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,502 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 J. Blake / mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/jinzhu/copier" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| const pkInfo = "packet type %v, %s" | ||||
|  | ||||
| var packetList = []byte{ | ||||
| 	Connect, | ||||
| 	Connack, | ||||
| 	Publish, | ||||
| 	Puback, | ||||
| 	Pubrec, | ||||
| 	Pubrel, | ||||
| 	Pubcomp, | ||||
| 	Subscribe, | ||||
| 	Suback, | ||||
| 	Unsubscribe, | ||||
| 	Unsuback, | ||||
| 	Pingreq, | ||||
| 	Pingresp, | ||||
| 	Disconnect, | ||||
| 	Auth, | ||||
| } | ||||
|  | ||||
| var pkTable = []TPacketCase{ | ||||
| 	TPacketData[Connect].Get(TConnectMqtt311), | ||||
| 	TPacketData[Connect].Get(TConnectMqtt5), | ||||
| 	TPacketData[Connect].Get(TConnectUserPassLWT), | ||||
| 	TPacketData[Connack].Get(TConnackAcceptedMqtt5), | ||||
| 	TPacketData[Connack].Get(TConnackAcceptedNoSession), | ||||
| 	TPacketData[Publish].Get(TPublishBasic), | ||||
| 	TPacketData[Publish].Get(TPublishMqtt5), | ||||
| 	TPacketData[Puback].Get(TPuback), | ||||
| 	TPacketData[Pubrec].Get(TPubrec), | ||||
| 	TPacketData[Pubrel].Get(TPubrel), | ||||
| 	TPacketData[Pubcomp].Get(TPubcomp), | ||||
| 	TPacketData[Subscribe].Get(TSubscribe), | ||||
| 	TPacketData[Subscribe].Get(TSubscribeMqtt5), | ||||
| 	TPacketData[Suback].Get(TSuback), | ||||
| 	TPacketData[Unsubscribe].Get(TUnsubscribe), | ||||
| 	TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5), | ||||
| 	TPacketData[Pingreq].Get(TPingreq), | ||||
| 	TPacketData[Pingresp].Get(TPingresp), | ||||
| 	TPacketData[Disconnect].Get(TDisconnect), | ||||
| 	TPacketData[Disconnect].Get(TDisconnectMqtt5), | ||||
| } | ||||
|  | ||||
| func TestNewPackets(t *testing.T) { | ||||
| 	s := NewPackets() | ||||
| 	require.NotNil(t, s.internal) | ||||
| } | ||||
|  | ||||
| func TestPacketsAdd(t *testing.T) { | ||||
| 	s := NewPackets() | ||||
| 	s.Add("cl1", Packet{}) | ||||
| 	require.Contains(t, s.internal, "cl1") | ||||
| } | ||||
|  | ||||
| func TestPacketsGet(t *testing.T) { | ||||
| 	s := NewPackets() | ||||
| 	s.Add("cl1", Packet{TopicName: "a1"}) | ||||
| 	s.Add("cl2", Packet{TopicName: "a2"}) | ||||
| 	require.Contains(t, s.internal, "cl1") | ||||
| 	require.Contains(t, s.internal, "cl2") | ||||
|  | ||||
| 	pk, ok := s.Get("cl1") | ||||
| 	require.True(t, ok) | ||||
| 	require.Equal(t, "a1", pk.TopicName) | ||||
| } | ||||
|  | ||||
| func TestPacketsGetAll(t *testing.T) { | ||||
| 	s := NewPackets() | ||||
| 	s.Add("cl1", Packet{TopicName: "a1"}) | ||||
| 	s.Add("cl2", Packet{TopicName: "a2"}) | ||||
| 	s.Add("cl3", Packet{TopicName: "a3"}) | ||||
| 	require.Contains(t, s.internal, "cl1") | ||||
| 	require.Contains(t, s.internal, "cl2") | ||||
| 	require.Contains(t, s.internal, "cl3") | ||||
|  | ||||
| 	subs := s.GetAll() | ||||
| 	require.Len(t, subs, 3) | ||||
| } | ||||
|  | ||||
| func TestPacketsLen(t *testing.T) { | ||||
| 	s := NewPackets() | ||||
| 	s.Add("cl1", Packet{TopicName: "a1"}) | ||||
| 	s.Add("cl2", Packet{TopicName: "a2"}) | ||||
| 	require.Contains(t, s.internal, "cl1") | ||||
| 	require.Contains(t, s.internal, "cl2") | ||||
| 	require.Equal(t, 2, s.Len()) | ||||
| } | ||||
|  | ||||
| func TestSPacketsDelete(t *testing.T) { | ||||
| 	s := NewPackets() | ||||
| 	s.Add("cl1", Packet{TopicName: "a1"}) | ||||
| 	require.Contains(t, s.internal, "cl1") | ||||
|  | ||||
| 	s.Delete("cl1") | ||||
| 	_, ok := s.Get("cl1") | ||||
| 	require.False(t, ok) | ||||
| } | ||||
|  | ||||
| func TestFormatPacketID(t *testing.T) { | ||||
| 	for _, id := range []uint16{0, 7, 0x100, 0xffff} { | ||||
| 		packet := &Packet{PacketID: id} | ||||
| 		require.Equal(t, fmt.Sprint(id), packet.FormatID()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestSubscriptionOptionsEncodeDecode(t *testing.T) { | ||||
| 	p := &Subscription{ | ||||
| 		Qos:               2, | ||||
| 		NoLocal:           true, | ||||
| 		RetainAsPublished: true, | ||||
| 		RetainHandling:    2, | ||||
| 	} | ||||
| 	x := new(Subscription) | ||||
| 	x.decode(p.encode()) | ||||
| 	require.Equal(t, *p, *x) | ||||
|  | ||||
| 	p = &Subscription{ | ||||
| 		Qos:               1, | ||||
| 		NoLocal:           false, | ||||
| 		RetainAsPublished: false, | ||||
| 		RetainHandling:    1, | ||||
| 	} | ||||
| 	x = new(Subscription) | ||||
| 	x.decode(p.encode()) | ||||
| 	require.Equal(t, *p, *x) | ||||
| } | ||||
|  | ||||
| func TestPacketEncode(t *testing.T) { | ||||
| 	for _, pkt := range packetList { | ||||
| 		require.Contains(t, TPacketData, pkt) | ||||
| 		for _, wanted := range TPacketData[pkt] { | ||||
| 			t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) { | ||||
| 				if !encodeTestOK(wanted) { | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				pk := new(Packet) | ||||
| 				copier.Copy(pk, wanted.Packet) | ||||
| 				require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc) | ||||
|  | ||||
| 				pk.Mods.AllowResponseInfo = true | ||||
|  | ||||
| 				buf := new(bytes.Buffer) | ||||
| 				var err error | ||||
| 				switch pkt { | ||||
| 				case Connect: | ||||
| 					err = pk.ConnectEncode(buf) | ||||
| 				case Connack: | ||||
| 					err = pk.ConnackEncode(buf) | ||||
| 				case Publish: | ||||
| 					err = pk.PublishEncode(buf) | ||||
| 				case Puback: | ||||
| 					err = pk.PubackEncode(buf) | ||||
| 				case Pubrec: | ||||
| 					err = pk.PubrecEncode(buf) | ||||
| 				case Pubrel: | ||||
| 					err = pk.PubrelEncode(buf) | ||||
| 				case Pubcomp: | ||||
| 					err = pk.PubcompEncode(buf) | ||||
| 				case Subscribe: | ||||
| 					err = pk.SubscribeEncode(buf) | ||||
| 				case Suback: | ||||
| 					err = pk.SubackEncode(buf) | ||||
| 				case Unsubscribe: | ||||
| 					err = pk.UnsubscribeEncode(buf) | ||||
| 				case Unsuback: | ||||
| 					err = pk.UnsubackEncode(buf) | ||||
| 				case Pingreq: | ||||
| 					err = pk.PingreqEncode(buf) | ||||
| 				case Pingresp: | ||||
| 					err = pk.PingrespEncode(buf) | ||||
| 				case Disconnect: | ||||
| 					err = pk.DisconnectEncode(buf) | ||||
| 				case Auth: | ||||
| 					err = pk.AuthEncode(buf) | ||||
| 				} | ||||
| 				if wanted.Expect != nil { | ||||
| 					require.Error(t, err, pkInfo, pkt, wanted.Desc) | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				require.NoError(t, err, pkInfo, pkt, wanted.Desc) | ||||
| 				encoded := buf.Bytes() | ||||
|  | ||||
| 				// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc). | ||||
| 				if len(wanted.ActualBytes) > 0 { | ||||
| 					wanted.RawBytes = wanted.ActualBytes | ||||
| 				} | ||||
| 				require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc) | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPacketDecode(t *testing.T) { | ||||
| 	for _, pkt := range packetList { | ||||
| 		require.Contains(t, TPacketData, pkt) | ||||
| 		for _, wanted := range TPacketData[pkt] { | ||||
| 			t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) { | ||||
| 				if !decodeTestOK(wanted) { | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				pk := &Packet{FixedHeader: FixedHeader{Type: pkt}} | ||||
| 				pk.Mods.AllowResponseInfo = true | ||||
| 				pk.FixedHeader.Decode(wanted.RawBytes[0]) | ||||
| 				if len(wanted.RawBytes) > 0 { | ||||
| 					pk.FixedHeader.Remaining = int(wanted.RawBytes[1]) | ||||
| 				} | ||||
|  | ||||
| 				if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 { | ||||
| 					pk.ProtocolVersion = wanted.Packet.ProtocolVersion | ||||
| 				} | ||||
|  | ||||
| 				buf := wanted.RawBytes[2:] | ||||
| 				var err error | ||||
| 				switch pkt { | ||||
| 				case Connect: | ||||
| 					err = pk.ConnectDecode(buf) | ||||
| 				case Connack: | ||||
| 					err = pk.ConnackDecode(buf) | ||||
| 				case Publish: | ||||
| 					err = pk.PublishDecode(buf) | ||||
| 				case Puback: | ||||
| 					err = pk.PubackDecode(buf) | ||||
| 				case Pubrec: | ||||
| 					err = pk.PubrecDecode(buf) | ||||
| 				case Pubrel: | ||||
| 					err = pk.PubrelDecode(buf) | ||||
| 				case Pubcomp: | ||||
| 					err = pk.PubcompDecode(buf) | ||||
| 				case Subscribe: | ||||
| 					err = pk.SubscribeDecode(buf) | ||||
| 				case Suback: | ||||
| 					err = pk.SubackDecode(buf) | ||||
| 				case Unsubscribe: | ||||
| 					err = pk.UnsubscribeDecode(buf) | ||||
| 				case Unsuback: | ||||
| 					err = pk.UnsubackDecode(buf) | ||||
| 				case Pingreq: | ||||
| 					err = pk.PingreqDecode(buf) | ||||
| 				case Pingresp: | ||||
| 					err = pk.PingrespDecode(buf) | ||||
| 				case Disconnect: | ||||
| 					err = pk.DisconnectDecode(buf) | ||||
| 				case Auth: | ||||
| 					err = pk.AuthDecode(buf) | ||||
| 				} | ||||
|  | ||||
| 				if wanted.FailFirst != nil { | ||||
| 					require.Error(t, err, pkInfo, pkt, wanted.Desc) | ||||
| 					require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc) | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				require.NoError(t, err, pkInfo, pkt, wanted.Desc) | ||||
|  | ||||
| 				require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc) | ||||
|  | ||||
| 				require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc) | ||||
|  | ||||
| 				if pkt == Connect { | ||||
| 					// we use ProtocolVersion for controlling packet encoding, but we don't need to test | ||||
| 					// against it unless it's a connect packet. | ||||
| 					require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc) | ||||
| 				} | ||||
| 				require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc) | ||||
|  | ||||
| 				require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc) | ||||
|  | ||||
| 				require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc) | ||||
|  | ||||
| 				require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc) | ||||
| 				require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc) | ||||
|  | ||||
| 				require.EqualValues(t, wanted.Packet.Properties, pk.Properties) | ||||
| 				require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties) | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestValidate(t *testing.T) { | ||||
| 	for _, pkt := range packetList { | ||||
| 		require.Contains(t, TPacketData, pkt) | ||||
| 		for _, wanted := range TPacketData[pkt] { | ||||
| 			t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) { | ||||
| 				if wanted.Group == "validate" || wanted.Primary { | ||||
| 					pk := wanted.Packet | ||||
| 					var err error | ||||
| 					switch pkt { | ||||
| 					case Connect: | ||||
| 						err = pk.ConnectValidate() | ||||
| 					case Publish: | ||||
| 						err = pk.PublishValidate(1024) | ||||
| 					case Subscribe: | ||||
| 						err = pk.SubscribeValidate() | ||||
| 					case Unsubscribe: | ||||
| 						err = pk.UnsubscribeValidate() | ||||
| 					case Auth: | ||||
| 						err = pk.AuthValidate() | ||||
| 					} | ||||
|  | ||||
| 					if wanted.Expect != nil { | ||||
| 						require.Error(t, err, pkInfo, pkt, wanted.Desc) | ||||
| 						require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAckValidatePubrec(t *testing.T) { | ||||
| 	for _, b := range []byte{ | ||||
| 		CodeSuccess.Code, | ||||
| 		CodeNoMatchingSubscribers.Code, | ||||
| 		ErrUnspecifiedError.Code, | ||||
| 		ErrImplementationSpecificError.Code, | ||||
| 		ErrNotAuthorized.Code, | ||||
| 		ErrTopicNameInvalid.Code, | ||||
| 		ErrPacketIdentifierInUse.Code, | ||||
| 		ErrQuotaExceeded.Code, | ||||
| 		ErrPayloadFormatInvalid.Code, | ||||
| 	} { | ||||
| 		pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b} | ||||
| 		require.True(t, pk.ReasonCodeValid()) | ||||
| 	} | ||||
| 	pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code} | ||||
| 	require.False(t, pk.ReasonCodeValid()) | ||||
| } | ||||
|  | ||||
| func TestAckValidatePubrel(t *testing.T) { | ||||
| 	for _, b := range []byte{ | ||||
| 		CodeSuccess.Code, | ||||
| 		ErrPacketIdentifierNotFound.Code, | ||||
| 	} { | ||||
| 		pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b} | ||||
| 		require.True(t, pk.ReasonCodeValid()) | ||||
| 	} | ||||
| 	pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code} | ||||
| 	require.False(t, pk.ReasonCodeValid()) | ||||
| } | ||||
|  | ||||
| func TestAckValidatePubcomp(t *testing.T) { | ||||
| 	for _, b := range []byte{ | ||||
| 		CodeSuccess.Code, | ||||
| 		ErrPacketIdentifierNotFound.Code, | ||||
| 	} { | ||||
| 		pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b} | ||||
| 		require.True(t, pk.ReasonCodeValid()) | ||||
| 	} | ||||
| 	pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code} | ||||
| 	require.False(t, pk.ReasonCodeValid()) | ||||
| } | ||||
|  | ||||
| func TestAckValidateSuback(t *testing.T) { | ||||
| 	for _, b := range []byte{ | ||||
| 		CodeGrantedQos0.Code, | ||||
| 		CodeGrantedQos1.Code, | ||||
| 		CodeGrantedQos2.Code, | ||||
| 		ErrUnspecifiedError.Code, | ||||
| 		ErrImplementationSpecificError.Code, | ||||
| 		ErrNotAuthorized.Code, | ||||
| 		ErrTopicFilterInvalid.Code, | ||||
| 		ErrPacketIdentifierInUse.Code, | ||||
| 		ErrQuotaExceeded.Code, | ||||
| 		ErrSharedSubscriptionsNotSupported.Code, | ||||
| 		ErrSubscriptionIdentifiersNotSupported.Code, | ||||
| 		ErrWildcardSubscriptionsNotSupported.Code, | ||||
| 	} { | ||||
| 		pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b} | ||||
| 		require.True(t, pk.ReasonCodeValid()) | ||||
| 	} | ||||
|  | ||||
| 	pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code} | ||||
| 	require.False(t, pk.ReasonCodeValid()) | ||||
| } | ||||
|  | ||||
| func TestAckValidateUnsuback(t *testing.T) { | ||||
| 	for _, b := range []byte{ | ||||
| 		CodeSuccess.Code, | ||||
| 		CodeNoSubscriptionExisted.Code, | ||||
| 		ErrUnspecifiedError.Code, | ||||
| 		ErrImplementationSpecificError.Code, | ||||
| 		ErrNotAuthorized.Code, | ||||
| 		ErrTopicFilterInvalid.Code, | ||||
| 		ErrPacketIdentifierInUse.Code, | ||||
| 	} { | ||||
| 		pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b} | ||||
| 		require.True(t, pk.ReasonCodeValid()) | ||||
| 	} | ||||
|  | ||||
| 	pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code} | ||||
| 	require.False(t, pk.ReasonCodeValid()) | ||||
| } | ||||
|  | ||||
| func TestReasonCodeValidMisc(t *testing.T) { | ||||
| 	pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code} | ||||
| 	require.True(t, pk.ReasonCodeValid()) | ||||
| } | ||||
|  | ||||
| func TestCopy(t *testing.T) { | ||||
| 	for _, tt := range pkTable { | ||||
| 		pkc := tt.Packet.Copy(true) | ||||
|  | ||||
| 		require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc) | ||||
|  | ||||
| 		require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc) | ||||
| 		require.EqualValues(t, pkc.Properties, tt.Packet.Properties) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMergeSubscription(t *testing.T) { | ||||
| 	sub := Subscription{ | ||||
| 		Filter:            "a/b/c", | ||||
| 		RetainHandling:    0, | ||||
| 		Qos:               0, | ||||
| 		RetainAsPublished: false, | ||||
| 		NoLocal:           false, | ||||
| 		Identifier:        1, | ||||
| 	} | ||||
|  | ||||
| 	sub2 := Subscription{ | ||||
| 		Filter:            "a/b/d", | ||||
| 		RetainHandling:    0, | ||||
| 		Qos:               2, | ||||
| 		RetainAsPublished: false, | ||||
| 		NoLocal:           true, | ||||
| 		Identifier:        2, | ||||
| 	} | ||||
|  | ||||
| 	expect := Subscription{ | ||||
| 		Filter:            "a/b/c", | ||||
| 		RetainHandling:    0, | ||||
| 		Qos:               2, | ||||
| 		RetainAsPublished: false, | ||||
| 		NoLocal:           true, | ||||
| 		Identifier:        1, | ||||
| 		Identifiers: map[string]int{ | ||||
| 			"a/b/c": 1, | ||||
| 			"a/b/d": 2, | ||||
| 		}, | ||||
| 	} | ||||
| 	require.Equal(t, expect, sub.Merge(sub2)) | ||||
| } | ||||
							
								
								
									
										478
									
								
								packets/properties.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										478
									
								
								packets/properties.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,478 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	PropPayloadFormat          byte = 1 | ||||
| 	PropMessageExpiryInterval  byte = 2 | ||||
| 	PropContentType            byte = 3 | ||||
| 	PropResponseTopic          byte = 8 | ||||
| 	PropCorrelationData        byte = 9 | ||||
| 	PropSubscriptionIdentifier byte = 11 | ||||
| 	PropSessionExpiryInterval  byte = 17 | ||||
| 	PropAssignedClientID       byte = 18 | ||||
| 	PropServerKeepAlive        byte = 19 | ||||
| 	PropAuthenticationMethod   byte = 21 | ||||
| 	PropAuthenticationData     byte = 22 | ||||
| 	PropRequestProblemInfo     byte = 23 | ||||
| 	PropWillDelayInterval      byte = 24 | ||||
| 	PropRequestResponseInfo    byte = 25 | ||||
| 	PropResponseInfo           byte = 26 | ||||
| 	PropServerReference        byte = 28 | ||||
| 	PropReasonString           byte = 31 | ||||
| 	PropReceiveMaximum         byte = 33 | ||||
| 	PropTopicAliasMaximum      byte = 34 | ||||
| 	PropTopicAlias             byte = 35 | ||||
| 	PropMaximumQos             byte = 36 | ||||
| 	PropRetainAvailable        byte = 37 | ||||
| 	PropUser                   byte = 38 | ||||
| 	PropMaximumPacketSize      byte = 39 | ||||
| 	PropWildcardSubAvailable   byte = 40 | ||||
| 	PropSubIDAvailable         byte = 41 | ||||
| 	PropSharedSubAvailable     byte = 42 | ||||
| ) | ||||
|  | ||||
| // validPacketProperties indicates which properties are valid for which packet types. | ||||
| var validPacketProperties = map[byte]map[byte]byte{ | ||||
| 	PropPayloadFormat:          {Publish: 1}, | ||||
| 	PropMessageExpiryInterval:  {Publish: 1}, | ||||
| 	PropContentType:            {Publish: 1}, | ||||
| 	PropResponseTopic:          {Publish: 1}, | ||||
| 	PropCorrelationData:        {Publish: 1}, | ||||
| 	PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1}, | ||||
| 	PropSessionExpiryInterval:  {Connect: 1, Connack: 1, Disconnect: 1}, | ||||
| 	PropAssignedClientID:       {Connack: 1}, | ||||
| 	PropServerKeepAlive:        {Connack: 1}, | ||||
| 	PropAuthenticationMethod:   {Connect: 1, Connack: 1, Auth: 1}, | ||||
| 	PropAuthenticationData:     {Connect: 1, Connack: 1, Auth: 1}, | ||||
| 	PropRequestProblemInfo:     {Connect: 1}, | ||||
| 	PropWillDelayInterval:      {Connect: 1}, | ||||
| 	PropRequestResponseInfo:    {Connect: 1}, | ||||
| 	PropResponseInfo:           {Connack: 1}, | ||||
| 	PropServerReference:        {Connack: 1, Disconnect: 1}, | ||||
| 	PropReasonString:           {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1}, | ||||
| 	PropReceiveMaximum:         {Connect: 1, Connack: 1}, | ||||
| 	PropTopicAliasMaximum:      {Connect: 1, Connack: 1}, | ||||
| 	PropTopicAlias:             {Publish: 1}, | ||||
| 	PropMaximumQos:             {Connack: 1}, | ||||
| 	PropRetainAvailable:        {Connack: 1}, | ||||
| 	PropUser:                   {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1}, | ||||
| 	PropMaximumPacketSize:      {Connect: 1, Connack: 1}, | ||||
| 	PropWildcardSubAvailable:   {Connack: 1}, | ||||
| 	PropSubIDAvailable:         {Connack: 1}, | ||||
| 	PropSharedSubAvailable:     {Connack: 1}, | ||||
| } | ||||
|  | ||||
| // UserProperty is an arbitrary key-value pair for a packet user properties array. | ||||
| type UserProperty struct { // [MQTT-1.5.7-1] | ||||
| 	Key string `json:"k"` | ||||
| 	Val string `json:"v"` | ||||
| } | ||||
|  | ||||
| // Properties contains all of the mqtt v5 properties available for a packet. | ||||
| // Some properties have valid values of 0 or not-present. In this case, we opt for | ||||
| // property flags to indicate the usage of property. | ||||
| // Refer to mqtt v5 2.2.2.2 Property spec for more information. | ||||
| type Properties struct { | ||||
| 	CorrelationData           []byte         `json:"cd"` | ||||
| 	SubscriptionIdentifier    []int          `json:"si"` | ||||
| 	AuthenticationData        []byte         `json:"ad"` | ||||
| 	User                      []UserProperty `json:"user"` | ||||
| 	ContentType               string         `json:"ct"` | ||||
| 	ResponseTopic             string         `json:"rt"` | ||||
| 	AssignedClientID          string         `json:"aci"` | ||||
| 	AuthenticationMethod      string         `json:"am"` | ||||
| 	ResponseInfo              string         `json:"ri"` | ||||
| 	ServerReference           string         `json:"sr"` | ||||
| 	ReasonString              string         `json:"rs"` | ||||
| 	MessageExpiryInterval     uint32         `json:"me"` | ||||
| 	SessionExpiryInterval     uint32         `json:"sei"` | ||||
| 	WillDelayInterval         uint32         `json:"wdi"` | ||||
| 	MaximumPacketSize         uint32         `json:"mps"` | ||||
| 	ServerKeepAlive           uint16         `json:"ska"` | ||||
| 	ReceiveMaximum            uint16         `json:"rm"` | ||||
| 	TopicAliasMaximum         uint16         `json:"tam"` | ||||
| 	TopicAlias                uint16         `json:"ta"` | ||||
| 	PayloadFormat             byte           `json:"pf"` | ||||
| 	PayloadFormatFlag         bool           `json:"fpf"` | ||||
| 	SessionExpiryIntervalFlag bool           `json:"fsei"` | ||||
| 	ServerKeepAliveFlag       bool           `json:"fska"` | ||||
| 	RequestProblemInfo        byte           `json:"rpi"` | ||||
| 	RequestProblemInfoFlag    bool           `json:"frpi"` | ||||
| 	RequestResponseInfo       byte           `json:"rri"` | ||||
| 	TopicAliasFlag            bool           `json:"fta"` | ||||
| 	MaximumQos                byte           `json:"mqos"` | ||||
| 	MaximumQosFlag            bool           `json:"fmqos"` | ||||
| 	RetainAvailable           byte           `json:"ra"` | ||||
| 	RetainAvailableFlag       bool           `json:"fra"` | ||||
| 	WildcardSubAvailable      byte           `json:"wsa"` | ||||
| 	WildcardSubAvailableFlag  bool           `json:"fwsa"` | ||||
| 	SubIDAvailable            byte           `json:"sida"` | ||||
| 	SubIDAvailableFlag        bool           `json:"fsida"` | ||||
| 	SharedSubAvailable        byte           `json:"ssa"` | ||||
| 	SharedSubAvailableFlag    bool           `json:"fssa"` | ||||
| } | ||||
|  | ||||
| // Copy creates a new Properties struct with copies of the values. | ||||
| func (p *Properties) Copy(allowTransfer bool) Properties { | ||||
| 	pr := Properties{ | ||||
| 		PayloadFormat:             p.PayloadFormat, // [MQTT-3.3.2-4] | ||||
| 		PayloadFormatFlag:         p.PayloadFormatFlag, | ||||
| 		MessageExpiryInterval:     p.MessageExpiryInterval, | ||||
| 		ContentType:               p.ContentType,   // [MQTT-3.3.2-20] | ||||
| 		ResponseTopic:             p.ResponseTopic, // [MQTT-3.3.2-15] | ||||
| 		SessionExpiryInterval:     p.SessionExpiryInterval, | ||||
| 		SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag, | ||||
| 		AssignedClientID:          p.AssignedClientID, | ||||
| 		ServerKeepAlive:           p.ServerKeepAlive, | ||||
| 		ServerKeepAliveFlag:       p.ServerKeepAliveFlag, | ||||
| 		AuthenticationMethod:      p.AuthenticationMethod, | ||||
| 		RequestProblemInfo:        p.RequestProblemInfo, | ||||
| 		RequestProblemInfoFlag:    p.RequestProblemInfoFlag, | ||||
| 		WillDelayInterval:         p.WillDelayInterval, | ||||
| 		RequestResponseInfo:       p.RequestResponseInfo, | ||||
| 		ResponseInfo:              p.ResponseInfo, | ||||
| 		ServerReference:           p.ServerReference, | ||||
| 		ReasonString:              p.ReasonString, | ||||
| 		ReceiveMaximum:            p.ReceiveMaximum, | ||||
| 		TopicAliasMaximum:         p.TopicAliasMaximum, | ||||
| 		TopicAlias:                0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27] | ||||
| 		MaximumQos:                p.MaximumQos, | ||||
| 		MaximumQosFlag:            p.MaximumQosFlag, | ||||
| 		RetainAvailable:           p.RetainAvailable, | ||||
| 		RetainAvailableFlag:       p.RetainAvailableFlag, | ||||
| 		MaximumPacketSize:         p.MaximumPacketSize, | ||||
| 		WildcardSubAvailable:      p.WildcardSubAvailable, | ||||
| 		WildcardSubAvailableFlag:  p.WildcardSubAvailableFlag, | ||||
| 		SubIDAvailable:            p.SubIDAvailable, | ||||
| 		SubIDAvailableFlag:        p.SubIDAvailableFlag, | ||||
| 		SharedSubAvailable:        p.SharedSubAvailable, | ||||
| 		SharedSubAvailableFlag:    p.SharedSubAvailableFlag, | ||||
| 	} | ||||
|  | ||||
| 	if allowTransfer { | ||||
| 		pr.TopicAlias = p.TopicAlias | ||||
| 		pr.TopicAliasFlag = p.TopicAliasFlag | ||||
| 	} | ||||
|  | ||||
| 	if len(p.CorrelationData) > 0 { | ||||
| 		pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16] | ||||
| 	} | ||||
|  | ||||
| 	if len(p.SubscriptionIdentifier) > 0 { | ||||
| 		pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...) | ||||
| 	} | ||||
|  | ||||
| 	if len(p.AuthenticationData) > 0 { | ||||
| 		pr.AuthenticationData = append([]byte{}, p.AuthenticationData...) | ||||
| 	} | ||||
|  | ||||
| 	if len(p.User) > 0 { | ||||
| 		pr.User = []UserProperty{} | ||||
| 		for _, v := range p.User { | ||||
| 			pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17] | ||||
| 				Key: v.Key, | ||||
| 				Val: v.Val, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return pr | ||||
| } | ||||
|  | ||||
| // canEncode returns true if the property type is valid for the packet type. | ||||
| func (p *Properties) canEncode(pkt byte, k byte) bool { | ||||
| 	return validPacketProperties[k][pkt] == 1 | ||||
| } | ||||
|  | ||||
| // Encode encodes properties into a bytes buffer. | ||||
| func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { | ||||
| 	if p == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	pkt := pk.FixedHeader.Type | ||||
|  | ||||
| 	if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag { | ||||
| 		buf.WriteByte(PropPayloadFormat) | ||||
| 		buf.WriteByte(p.PayloadFormat) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 { | ||||
| 		buf.WriteByte(PropMessageExpiryInterval) | ||||
| 		buf.Write(encodeUint32(p.MessageExpiryInterval)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropContentType) && p.ContentType != "" { | ||||
| 		buf.WriteByte(PropContentType) | ||||
| 		buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19] | ||||
| 	} | ||||
|  | ||||
| 	if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && //  [MQTT-3.3.2-14] | ||||
| 		p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28] | ||||
| 		buf.WriteByte(PropResponseTopic) | ||||
| 		buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13] | ||||
| 	} | ||||
|  | ||||
| 	if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28] | ||||
| 		buf.WriteByte(PropCorrelationData) | ||||
| 		buf.Write(encodeBytes(p.CorrelationData)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 { | ||||
| 		for _, v := range p.SubscriptionIdentifier { | ||||
| 			if v > 0 { | ||||
| 				buf.WriteByte(PropSubscriptionIdentifier) | ||||
| 				encodeLength(&buf, int64(v)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2] | ||||
| 		buf.WriteByte(PropSessionExpiryInterval) | ||||
| 		buf.Write(encodeUint32(p.SessionExpiryInterval)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" { | ||||
| 		buf.WriteByte(PropAssignedClientID) | ||||
| 		buf.Write(encodeString(p.AssignedClientID)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag { | ||||
| 		buf.WriteByte(PropServerKeepAlive) | ||||
| 		buf.Write(encodeUint16(p.ServerKeepAlive)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" { | ||||
| 		buf.WriteByte(PropAuthenticationMethod) | ||||
| 		buf.Write(encodeString(p.AuthenticationMethod)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 { | ||||
| 		buf.WriteByte(PropAuthenticationData) | ||||
| 		buf.Write(encodeBytes(p.AuthenticationData)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag { | ||||
| 		buf.WriteByte(PropRequestProblemInfo) | ||||
| 		buf.WriteByte(p.RequestProblemInfo) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 { | ||||
| 		buf.WriteByte(PropWillDelayInterval) | ||||
| 		buf.Write(encodeUint32(p.WillDelayInterval)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 { | ||||
| 		buf.WriteByte(PropRequestResponseInfo) | ||||
| 		buf.WriteByte(p.RequestResponseInfo) | ||||
| 	} | ||||
|  | ||||
| 	if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28] | ||||
| 		buf.WriteByte(PropResponseInfo) | ||||
| 		buf.Write(encodeString(p.ResponseInfo)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 { | ||||
| 		buf.WriteByte(PropServerReference) | ||||
| 		buf.Write(encodeString(p.ServerReference)) | ||||
| 	} | ||||
|  | ||||
| 	// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2] | ||||
| 	// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2] | ||||
| 	if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" { | ||||
| 		b := encodeString(p.ReasonString) | ||||
| 		if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize { | ||||
| 			buf.WriteByte(PropReasonString) | ||||
| 			buf.Write(b) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 { | ||||
| 		buf.WriteByte(PropReceiveMaximum) | ||||
| 		buf.Write(encodeUint16(p.ReceiveMaximum)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 { | ||||
| 		buf.WriteByte(PropTopicAliasMaximum) | ||||
| 		buf.Write(encodeUint16(p.TopicAliasMaximum)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8] | ||||
| 		buf.WriteByte(PropTopicAlias) | ||||
| 		buf.Write(encodeUint16(p.TopicAlias)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 { | ||||
| 		buf.WriteByte(PropMaximumQos) | ||||
| 		buf.WriteByte(p.MaximumQos) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag { | ||||
| 		buf.WriteByte(PropRetainAvailable) | ||||
| 		buf.WriteByte(p.RetainAvailable) | ||||
| 	} | ||||
|  | ||||
| 	if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) { | ||||
| 		pb := bytes.NewBuffer([]byte{}) | ||||
| 		for _, v := range p.User { | ||||
| 			pb.WriteByte(PropUser) | ||||
| 			pb.Write(encodeString(v.Key)) | ||||
| 			pb.Write(encodeString(v.Val)) | ||||
| 		} | ||||
| 		// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3] | ||||
| 		// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3] | ||||
| 		if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize { | ||||
| 			buf.Write(pb.Bytes()) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 { | ||||
| 		buf.WriteByte(PropMaximumPacketSize) | ||||
| 		buf.Write(encodeUint32(p.MaximumPacketSize)) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag { | ||||
| 		buf.WriteByte(PropWildcardSubAvailable) | ||||
| 		buf.WriteByte(p.WildcardSubAvailable) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag { | ||||
| 		buf.WriteByte(PropSubIDAvailable) | ||||
| 		buf.WriteByte(p.SubIDAvailable) | ||||
| 	} | ||||
|  | ||||
| 	if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag { | ||||
| 		buf.WriteByte(PropSharedSubAvailable) | ||||
| 		buf.WriteByte(p.SharedSubAvailable) | ||||
| 	} | ||||
|  | ||||
| 	encodeLength(b, int64(buf.Len())) | ||||
| 	buf.WriteTo(b) // [MQTT-3.1.3-10] | ||||
| } | ||||
|  | ||||
| // Decode decodes property bytes into a properties struct. | ||||
| func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { | ||||
| 	if p == nil { | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	n, _, err = DecodeLength(b) | ||||
| 	if err != nil { | ||||
| 		return n, err | ||||
| 	} | ||||
|  | ||||
| 	if n == 0 { | ||||
| 		return n, nil | ||||
| 	} | ||||
|  | ||||
| 	bt := b.Bytes() | ||||
| 	var k byte | ||||
| 	for offset := 0; offset < n; { | ||||
| 		k, offset, err = decodeByte(bt, offset) | ||||
| 		if err != nil { | ||||
| 			return n, err | ||||
| 		} | ||||
|  | ||||
| 		if _, ok := validPacketProperties[k][pk]; !ok { | ||||
| 			return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty) | ||||
| 		} | ||||
|  | ||||
| 		switch k { | ||||
| 		case PropPayloadFormat: | ||||
| 			p.PayloadFormat, offset, err = decodeByte(bt, offset) | ||||
| 			p.PayloadFormatFlag = true | ||||
| 		case PropMessageExpiryInterval: | ||||
| 			p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset) | ||||
| 		case PropContentType: | ||||
| 			p.ContentType, offset, err = decodeString(bt, offset) | ||||
| 		case PropResponseTopic: | ||||
| 			p.ResponseTopic, offset, err = decodeString(bt, offset) | ||||
| 		case PropCorrelationData: | ||||
| 			p.CorrelationData, offset, err = decodeBytes(bt, offset) | ||||
| 		case PropSubscriptionIdentifier: | ||||
| 			if p.SubscriptionIdentifier == nil { | ||||
| 				p.SubscriptionIdentifier = []int{} | ||||
| 			} | ||||
|  | ||||
| 			n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:])) | ||||
| 			if err != nil { | ||||
| 				return n, err | ||||
| 			} | ||||
| 			p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n) | ||||
| 			offset += bu | ||||
| 		case PropSessionExpiryInterval: | ||||
| 			p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset) | ||||
| 			p.SessionExpiryIntervalFlag = true | ||||
| 		case PropAssignedClientID: | ||||
| 			p.AssignedClientID, offset, err = decodeString(bt, offset) | ||||
| 		case PropServerKeepAlive: | ||||
| 			p.ServerKeepAlive, offset, err = decodeUint16(bt, offset) | ||||
| 			p.ServerKeepAliveFlag = true | ||||
| 		case PropAuthenticationMethod: | ||||
| 			p.AuthenticationMethod, offset, err = decodeString(bt, offset) | ||||
| 		case PropAuthenticationData: | ||||
| 			p.AuthenticationData, offset, err = decodeBytes(bt, offset) | ||||
| 		case PropRequestProblemInfo: | ||||
| 			p.RequestProblemInfo, offset, err = decodeByte(bt, offset) | ||||
| 			p.RequestProblemInfoFlag = true | ||||
| 		case PropWillDelayInterval: | ||||
| 			p.WillDelayInterval, offset, err = decodeUint32(bt, offset) | ||||
| 		case PropRequestResponseInfo: | ||||
| 			p.RequestResponseInfo, offset, err = decodeByte(bt, offset) | ||||
| 		case PropResponseInfo: | ||||
| 			p.ResponseInfo, offset, err = decodeString(bt, offset) | ||||
| 		case PropServerReference: | ||||
| 			p.ServerReference, offset, err = decodeString(bt, offset) | ||||
| 		case PropReasonString: | ||||
| 			p.ReasonString, offset, err = decodeString(bt, offset) | ||||
| 		case PropReceiveMaximum: | ||||
| 			p.ReceiveMaximum, offset, err = decodeUint16(bt, offset) | ||||
| 		case PropTopicAliasMaximum: | ||||
| 			p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset) | ||||
| 		case PropTopicAlias: | ||||
| 			p.TopicAlias, offset, err = decodeUint16(bt, offset) | ||||
| 			p.TopicAliasFlag = true | ||||
| 		case PropMaximumQos: | ||||
| 			p.MaximumQos, offset, err = decodeByte(bt, offset) | ||||
| 			p.MaximumQosFlag = true | ||||
| 		case PropRetainAvailable: | ||||
| 			p.RetainAvailable, offset, err = decodeByte(bt, offset) | ||||
| 			p.RetainAvailableFlag = true | ||||
| 		case PropUser: | ||||
| 			var k, v string | ||||
| 			k, offset, err = decodeString(bt, offset) | ||||
| 			if err != nil { | ||||
| 				return n, err | ||||
| 			} | ||||
| 			v, offset, err = decodeString(bt, offset) | ||||
| 			p.User = append(p.User, UserProperty{Key: k, Val: v}) | ||||
| 		case PropMaximumPacketSize: | ||||
| 			p.MaximumPacketSize, offset, err = decodeUint32(bt, offset) | ||||
| 		case PropWildcardSubAvailable: | ||||
| 			p.WildcardSubAvailable, offset, err = decodeByte(bt, offset) | ||||
| 			p.WildcardSubAvailableFlag = true | ||||
| 		case PropSubIDAvailable: | ||||
| 			p.SubIDAvailable, offset, err = decodeByte(bt, offset) | ||||
| 			p.SubIDAvailableFlag = true | ||||
| 		case PropSharedSubAvailable: | ||||
| 			p.SharedSubAvailable, offset, err = decodeByte(bt, offset) | ||||
| 			p.SharedSubAvailableFlag = true | ||||
| 		} | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return n, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return n, nil | ||||
| } | ||||
							
								
								
									
										333
									
								
								packets/properties_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										333
									
								
								packets/properties_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,333 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	propertiesStruct = Properties{ | ||||
| 		PayloadFormat:             byte(1), // UTF-8 Format | ||||
| 		PayloadFormatFlag:         true, | ||||
| 		MessageExpiryInterval:     uint32(2), | ||||
| 		ContentType:               "text/plain", | ||||
| 		ResponseTopic:             "a/b/c", | ||||
| 		CorrelationData:           []byte("data"), | ||||
| 		SubscriptionIdentifier:    []int{322122}, | ||||
| 		SessionExpiryInterval:     uint32(120), | ||||
| 		SessionExpiryIntervalFlag: true, | ||||
| 		AssignedClientID:          "mochi-v5", | ||||
| 		ServerKeepAlive:           uint16(20), | ||||
| 		ServerKeepAliveFlag:       true, | ||||
| 		AuthenticationMethod:      "SHA-1", | ||||
| 		AuthenticationData:        []byte("auth-data"), | ||||
| 		RequestProblemInfo:        byte(1), | ||||
| 		RequestProblemInfoFlag:    true, | ||||
| 		WillDelayInterval:         uint32(600), | ||||
| 		RequestResponseInfo:       byte(1), | ||||
| 		ResponseInfo:              "response", | ||||
| 		ServerReference:           "mochi-2", | ||||
| 		ReasonString:              "reason", | ||||
| 		ReceiveMaximum:            uint16(500), | ||||
| 		TopicAliasMaximum:         uint16(999), | ||||
| 		TopicAlias:                uint16(3), | ||||
| 		TopicAliasFlag:            true, | ||||
| 		MaximumQos:                byte(1), | ||||
| 		MaximumQosFlag:            true, | ||||
| 		RetainAvailable:           byte(1), | ||||
| 		RetainAvailableFlag:       true, | ||||
| 		User: []UserProperty{ | ||||
| 			{ | ||||
| 				Key: "hello", | ||||
| 				Val: "世界", | ||||
| 			}, | ||||
| 			{ | ||||
| 				Key: "key2", | ||||
| 				Val: "value2", | ||||
| 			}, | ||||
| 		}, | ||||
| 		MaximumPacketSize:        uint32(32000), | ||||
| 		WildcardSubAvailable:     byte(1), | ||||
| 		WildcardSubAvailableFlag: true, | ||||
| 		SubIDAvailable:           byte(1), | ||||
| 		SubIDAvailableFlag:       true, | ||||
| 		SharedSubAvailable:       byte(1), | ||||
| 		SharedSubAvailableFlag:   true, | ||||
| 	} | ||||
|  | ||||
| 	propertiesBytes = []byte{ | ||||
| 		172, 1, // VBI | ||||
|  | ||||
| 		// Payload Format (1) (vbi:2) | ||||
| 		1, 1, | ||||
|  | ||||
| 		// Message Expiry (2) (vbi:7) | ||||
| 		2, 0, 0, 0, 2, | ||||
|  | ||||
| 		// Content Type (3) (vbi:20) | ||||
| 		3, | ||||
| 		0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n', | ||||
|  | ||||
| 		// Response Topic (8) (vbi:28) | ||||
| 		8, | ||||
| 		0, 5, 'a', '/', 'b', '/', 'c', | ||||
|  | ||||
| 		// Correlations Data (9) (vbi:35) | ||||
| 		9, | ||||
| 		0, 4, 'd', 'a', 't', 'a', | ||||
|  | ||||
| 		// Subscription Identifier (11) (vbi:39) | ||||
| 		11, | ||||
| 		202, 212, 19, | ||||
|  | ||||
| 		// Session Expiry Interval (17) (vbi:43) | ||||
| 		17, | ||||
| 		0, 0, 0, 120, | ||||
|  | ||||
| 		// Assigned Client ID (18) (vbi:55) | ||||
| 		18, | ||||
| 		0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5', | ||||
|  | ||||
| 		// Server Keep Alive (19) (vbi:58) | ||||
| 		19, | ||||
| 		0, 20, | ||||
|  | ||||
| 		// Authentication Method (21) (vbi:66) | ||||
| 		21, | ||||
| 		0, 5, 'S', 'H', 'A', '-', '1', | ||||
|  | ||||
| 		// Authentication Data (22) (vbi:78) | ||||
| 		22, | ||||
| 		0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', | ||||
|  | ||||
| 		// Request Problem Info (23) (vbi:80) | ||||
| 		23, 1, | ||||
|  | ||||
| 		// Will Delay Interval (24) (vbi:85) | ||||
| 		24, | ||||
| 		0, 0, 2, 88, | ||||
|  | ||||
| 		// Request Response Info (25) (vbi:87) | ||||
| 		25, 1, | ||||
|  | ||||
| 		// Response Info (26) (vbi:98) | ||||
| 		26, | ||||
| 		0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e', | ||||
|  | ||||
| 		// Server Reference (28) (vbi:108) | ||||
| 		28, | ||||
| 		0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', | ||||
|  | ||||
| 		// Reason String (31) (vbi:117) | ||||
| 		31, | ||||
| 		0, 6, 'r', 'e', 'a', 's', 'o', 'n', | ||||
|  | ||||
| 		// Receive Maximum (33) (vbi:120) | ||||
| 		33, | ||||
| 		1, 244, | ||||
|  | ||||
| 		// Topic Alias Maximum (34) (vbi:123) | ||||
| 		34, | ||||
| 		3, 231, | ||||
|  | ||||
| 		// Topic Alias (35) (vbi:126) | ||||
| 		35, | ||||
| 		0, 3, | ||||
|  | ||||
| 		// Maximum Qos (36) (vbi:128) | ||||
| 		36, 1, | ||||
|  | ||||
| 		// Retain Available (37) (vbi: 130) | ||||
| 		37, 1, | ||||
|  | ||||
| 		// User Properties (38) (vbi:161) | ||||
| 		38, | ||||
| 		0, 5, 'h', 'e', 'l', 'l', 'o', | ||||
| 		0, 6, 228, 184, 150, 231, 149, 140, | ||||
| 		38, | ||||
| 		0, 4, 'k', 'e', 'y', '2', | ||||
| 		0, 6, 'v', 'a', 'l', 'u', 'e', '2', | ||||
|  | ||||
| 		// Maximum Packet Size (39) (vbi:166) | ||||
| 		39, | ||||
| 		0, 0, 125, 0, | ||||
|  | ||||
| 		// Wildcard Subscriptions Available (40)  (vbi:168) | ||||
| 		40, 1, | ||||
|  | ||||
| 		// Subscription ID Available (41) (vbi:170) | ||||
| 		41, 1, | ||||
|  | ||||
| 		// Shared Subscriptions Available (42) (vbi:172) | ||||
| 		42, 1, | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| func init() { | ||||
| 	validPacketProperties[PropPayloadFormat][Reserved] = 1 | ||||
| 	validPacketProperties[PropMessageExpiryInterval][Reserved] = 1 | ||||
| 	validPacketProperties[PropContentType][Reserved] = 1 | ||||
| 	validPacketProperties[PropResponseTopic][Reserved] = 1 | ||||
| 	validPacketProperties[PropCorrelationData][Reserved] = 1 | ||||
| 	validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1 | ||||
| 	validPacketProperties[PropSessionExpiryInterval][Reserved] = 1 | ||||
| 	validPacketProperties[PropAssignedClientID][Reserved] = 1 | ||||
| 	validPacketProperties[PropServerKeepAlive][Reserved] = 1 | ||||
| 	validPacketProperties[PropAuthenticationMethod][Reserved] = 1 | ||||
| 	validPacketProperties[PropAuthenticationData][Reserved] = 1 | ||||
| 	validPacketProperties[PropRequestProblemInfo][Reserved] = 1 | ||||
| 	validPacketProperties[PropWillDelayInterval][Reserved] = 1 | ||||
| 	validPacketProperties[PropRequestResponseInfo][Reserved] = 1 | ||||
| 	validPacketProperties[PropResponseInfo][Reserved] = 1 | ||||
| 	validPacketProperties[PropServerReference][Reserved] = 1 | ||||
| 	validPacketProperties[PropReasonString][Reserved] = 1 | ||||
| 	validPacketProperties[PropReceiveMaximum][Reserved] = 1 | ||||
| 	validPacketProperties[PropTopicAliasMaximum][Reserved] = 1 | ||||
| 	validPacketProperties[PropTopicAlias][Reserved] = 1 | ||||
| 	validPacketProperties[PropMaximumQos][Reserved] = 1 | ||||
| 	validPacketProperties[PropRetainAvailable][Reserved] = 1 | ||||
| 	validPacketProperties[PropUser][Reserved] = 1 | ||||
| 	validPacketProperties[PropMaximumPacketSize][Reserved] = 1 | ||||
| 	validPacketProperties[PropWildcardSubAvailable][Reserved] = 1 | ||||
| 	validPacketProperties[PropSubIDAvailable][Reserved] = 1 | ||||
| 	validPacketProperties[PropSharedSubAvailable][Reserved] = 1 | ||||
| } | ||||
|  | ||||
| func TestEncodeProperties(t *testing.T) { | ||||
| 	props := propertiesStruct | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0) | ||||
| 	require.Equal(t, propertiesBytes, b.Bytes()) | ||||
| } | ||||
|  | ||||
| func TestEncodePropertiesDisallowProblemInfo(t *testing.T) { | ||||
| 	props := propertiesStruct | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0) | ||||
| 	require.NotEqual(t, propertiesBytes, b.Bytes()) | ||||
| 	require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6})) | ||||
| 	require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5})) | ||||
| 	require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8})) | ||||
| } | ||||
|  | ||||
| func TestEncodePropertiesDisallowResponseInfo(t *testing.T) { | ||||
| 	props := propertiesStruct | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0) | ||||
| 	require.NotEqual(t, propertiesBytes, b.Bytes()) | ||||
| 	require.NotContains(t, b.Bytes(), []byte{8, 0, 5}) | ||||
| 	require.NotContains(t, b.Bytes(), []byte{9, 0, 4}) | ||||
| } | ||||
|  | ||||
| func TestEncodePropertiesNil(t *testing.T) { | ||||
| 	type tmp struct { | ||||
| 		p *Properties | ||||
| 	} | ||||
|  | ||||
| 	pr := tmp{} | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0) | ||||
| 	require.Equal(t, []byte{}, b.Bytes()) | ||||
| } | ||||
|  | ||||
| func TestEncodeZeroProperties(t *testing.T) { | ||||
| 	// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero. | ||||
| 	props := new(Properties) | ||||
| 	b := bytes.NewBuffer([]byte{}) | ||||
| 	props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0) | ||||
| 	require.Equal(t, []byte{0x00}, b.Bytes()) | ||||
| } | ||||
|  | ||||
| func TestDecodeProperties(t *testing.T) { | ||||
| 	b := bytes.NewBuffer(propertiesBytes) | ||||
|  | ||||
| 	props := new(Properties) | ||||
| 	n, err := props.Decode(Reserved, b) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, 172, n) | ||||
| 	require.EqualValues(t, propertiesStruct, *props) | ||||
| } | ||||
|  | ||||
| func TestDecodePropertiesNil(t *testing.T) { | ||||
| 	b := bytes.NewBuffer(propertiesBytes) | ||||
|  | ||||
| 	type tmp struct { | ||||
| 		p *Properties | ||||
| 	} | ||||
|  | ||||
| 	pr := tmp{} | ||||
| 	n, err := pr.p.Decode(Reserved, b) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, 0, n) | ||||
| } | ||||
|  | ||||
| func TestDecodePropertiesBadInitialVBI(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) | ||||
| 	props := new(Properties) | ||||
| 	_, err := props.Decode(Reserved, b) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, ErrMalformedVariableByteInteger, err) | ||||
| } | ||||
|  | ||||
| func TestDecodePropertiesZeroLengthVBI(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{0}) | ||||
| 	props := new(Properties) | ||||
| 	_, err := props.Decode(Reserved, b) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, props, new(Properties)) | ||||
| } | ||||
|  | ||||
| func TestDecodePropertiesBadKeyByte(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{64, 1}) | ||||
| 	props := new(Properties) | ||||
| 	_, err := props.Decode(Reserved, b) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange) | ||||
| } | ||||
|  | ||||
| func TestDecodePropertiesInvalidForPacket(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{1, 99}) | ||||
| 	props := new(Properties) | ||||
| 	_, err := props.Decode(Reserved, b) | ||||
| 	require.Error(t, err) | ||||
| 	require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty) | ||||
| } | ||||
|  | ||||
| func TestDecodePropertiesGeneralFailure(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19}) | ||||
| 	props := new(Properties) | ||||
| 	_, err := props.Decode(Reserved, b) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestDecodePropertiesBadSubscriptionID(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255}) | ||||
| 	props := new(Properties) | ||||
| 	_, err := props.Decode(Reserved, b) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestDecodePropertiesBadUserProps(t *testing.T) { | ||||
| 	b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255}) | ||||
| 	props := new(Properties) | ||||
| 	_, err := props.Decode(Reserved, b) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestCopyProperties(t *testing.T) { | ||||
| 	require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true)) | ||||
| } | ||||
|  | ||||
| func TestCopyPropertiesNoTransfer(t *testing.T) { | ||||
| 	pkA := propertiesStruct | ||||
| 	pkB := pkA.Copy(false) | ||||
|  | ||||
| 	// Properties which should never be transferred from one connection to another | ||||
| 	require.Equal(t, uint16(0), pkB.TopicAlias) | ||||
| } | ||||
							
								
								
									
										3802
									
								
								packets/tpackets.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3802
									
								
								packets/tpackets.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										33
									
								
								packets/tpackets_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								packets/tpackets_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | ||||
| // SPDX-License-Identifier: MIT | ||||
| // SPDX-FileCopyrightText: 2022 mochi-co | ||||
| // SPDX-FileContributor: mochi-co | ||||
|  | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func encodeTestOK(wanted TPacketCase) bool { | ||||
| 	if wanted.RawBytes == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if wanted.Group != "" && wanted.Group != "encode" { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func decodeTestOK(wanted TPacketCase) bool { | ||||
| 	if wanted.Group != "" && wanted.Group != "decode" { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func TestTPacketCaseGet(t *testing.T) { | ||||
| 	require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311)) | ||||
| 	require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128))) | ||||
| } | ||||
| @@ -1,36 +0,0 @@ | ||||
| package events | ||||
|  | ||||
| import ( | ||||
| 	"github.com/mochi-co/mqtt/server/internal/clients" | ||||
| 	"github.com/mochi-co/mqtt/server/internal/packets" | ||||
| ) | ||||
|  | ||||
| type Events struct { | ||||
| 	OnMessage // published message receieved. | ||||
| } | ||||
|  | ||||
| type Packet packets.Packet | ||||
|  | ||||
| type Client struct { | ||||
| 	ID       string | ||||
| 	Listener string | ||||
| } | ||||
|  | ||||
| // FromClient returns an event client from a client. | ||||
| func FromClient(cl clients.Client) Client { | ||||
| 	return Client{ | ||||
| 		ID:       cl.ID, | ||||
| 		Listener: cl.Listener, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // OnMessage function is called when a publish message is received. Note, | ||||
| // this hook is ONLY called by connected client publishers, it is not triggered when | ||||
| // using the direct s.Publish method. The function receives the sent message and the | ||||
| // data of the client who published it, and allows the packet to be modified | ||||
| // before it is dispatched to subscribers. If no modification is required, return | ||||
| // the original packet data. If an error occurs, the original packet will | ||||
| // be dispatched as if the event hook had not been triggered. | ||||
| // This function will block message dispatching until it returns. To minimise this, | ||||
| // have the function open a new goroutine on the embedding side. | ||||
| type OnMessage func(Client, Packet) (Packet, error) | ||||
| @@ -1,206 +0,0 @@ | ||||
| package circ | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	DefaultBufferSize int = 1024 * 256 // the default size of the buffer in bytes. | ||||
| 	DefaultBlockSize  int = 1024 * 8   // the default size per R/W block in bytes. | ||||
|  | ||||
| 	ErrOutOfRange        = errors.New("Indexes out of range") | ||||
| 	ErrInsufficientBytes = errors.New("Insufficient bytes to return") | ||||
| ) | ||||
|  | ||||
| // buffer contains core values and methods to be included in a reader or writer. | ||||
| type Buffer struct { | ||||
| 	Mu    sync.RWMutex // the buffer needs it's own mutex to work properly. | ||||
| 	ID    string       // the identifier of the buffer. This is used in debug output. | ||||
| 	size  int          // the size of the buffer. | ||||
| 	mask  int          // a bitmask of the buffer size (size-1). | ||||
| 	block int          // the size of the R/W block. | ||||
| 	buf   []byte       // the bytes buffer. | ||||
| 	tmp   []byte       // a temporary buffer. | ||||
| 	head  int64        // the current position in the sequence - a forever increasing index. | ||||
| 	tail  int64        // the committed position in the sequence - a forever increasing index. | ||||
| 	rcond *sync.Cond   // the sync condition for the buffer reader. | ||||
| 	wcond *sync.Cond   // the sync condition for the buffer writer. | ||||
| 	done  int64        // indicates that the buffer is closed. | ||||
| 	State int64        // indicates whether the buffer is reading from (1) or writing to (2). | ||||
| } | ||||
|  | ||||
| // NewBuffer returns a new instance of buffer. You should call NewReader or | ||||
| // NewWriter instead of this function. | ||||
| func NewBuffer(size, block int) Buffer { | ||||
| 	if size == 0 { | ||||
| 		size = DefaultBufferSize | ||||
| 	} | ||||
|  | ||||
| 	if block == 0 { | ||||
| 		block = DefaultBlockSize | ||||
| 	} | ||||
|  | ||||
| 	if size < 2*block { | ||||
| 		size = 2 * block | ||||
| 	} | ||||
|  | ||||
| 	return Buffer{ | ||||
| 		size:  size, | ||||
| 		mask:  size - 1, | ||||
| 		block: block, | ||||
| 		buf:   make([]byte, size), | ||||
| 		rcond: sync.NewCond(new(sync.Mutex)), | ||||
| 		wcond: sync.NewCond(new(sync.Mutex)), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // NewBufferFromSlice returns a new instance of buffer using a | ||||
| // pre-existing byte slice. | ||||
| func NewBufferFromSlice(block int, buf []byte) Buffer { | ||||
| 	l := len(buf) | ||||
|  | ||||
| 	if block == 0 { | ||||
| 		block = DefaultBlockSize | ||||
| 	} | ||||
|  | ||||
| 	b := Buffer{ | ||||
| 		size:  l, | ||||
| 		mask:  l - 1, | ||||
| 		block: block, | ||||
| 		buf:   buf, | ||||
| 		rcond: sync.NewCond(new(sync.Mutex)), | ||||
| 		wcond: sync.NewCond(new(sync.Mutex)), | ||||
| 	} | ||||
|  | ||||
| 	return b | ||||
| } | ||||
|  | ||||
| // Get will return the tail and head positions of the buffer. | ||||
| // This method is for use with testing. | ||||
| func (b *Buffer) GetPos() (int64, int64) { | ||||
| 	return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head) | ||||
| } | ||||
|  | ||||
| // SetPos sets the head and tail of the buffer. | ||||
| func (b *Buffer) SetPos(tail, head int64) { | ||||
| 	atomic.StoreInt64(&b.tail, tail) | ||||
| 	atomic.StoreInt64(&b.head, head) | ||||
| } | ||||
|  | ||||
| // Get returns the internal buffer. | ||||
| func (b *Buffer) Get() []byte { | ||||
| 	b.Mu.Lock() | ||||
| 	defer b.Mu.Unlock() | ||||
| 	return b.buf | ||||
| } | ||||
|  | ||||
| // Set writes bytes to a range of indexes in the byte buffer. | ||||
| func (b *Buffer) Set(p []byte, start, end int) error { | ||||
| 	b.Mu.Lock() | ||||
| 	defer b.Mu.Unlock() | ||||
|  | ||||
| 	if end > b.size || start > b.size { | ||||
| 		return ErrOutOfRange | ||||
| 	} | ||||
|  | ||||
| 	o := 0 | ||||
| 	for i := start; i < end; i++ { | ||||
| 		b.buf[i] = p[o] | ||||
| 		o++ | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Index returns the buffer-relative index of an integer. | ||||
| func (b *Buffer) Index(i int64) int { | ||||
| 	return b.mask & int(i) | ||||
| } | ||||
|  | ||||
| // awaitEmpty will block until there is at least n bytes between | ||||
| // the head and the tail (looking forward). | ||||
| func (b *Buffer) awaitEmpty(n int) error { | ||||
| 	// If the head has wrapped behind the tail, and next will overrun tail, | ||||
| 	// then wait until tail has moved. | ||||
| 	b.rcond.L.Lock() | ||||
| 	for !b.checkEmpty(n) { | ||||
| 		if atomic.LoadInt64(&b.done) == 1 { | ||||
| 			b.rcond.L.Unlock() | ||||
| 			return io.EOF | ||||
| 		} | ||||
| 		b.rcond.Wait() | ||||
| 	} | ||||
| 	b.rcond.L.Unlock() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // awaitFilled will block until there are at least n bytes between the | ||||
| // tail and the head (looking forward). | ||||
| func (b *Buffer) awaitFilled(n int) error { | ||||
| 	// Because awaitCapacity prevents the head from overrunning the t | ||||
| 	// able on write, we can simply ensure there is enough space | ||||
| 	// the forever-incrementing tail and head integers. | ||||
| 	b.wcond.L.Lock() | ||||
| 	for !b.checkFilled(n) { | ||||
| 		if atomic.LoadInt64(&b.done) == 1 { | ||||
| 			b.wcond.L.Unlock() | ||||
| 			return io.EOF | ||||
| 		} | ||||
|  | ||||
| 		b.wcond.Wait() | ||||
| 	} | ||||
| 	b.wcond.L.Unlock() | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // checkEmpty returns true if there are at least n bytes between the head and | ||||
| // the tail. | ||||
| func (b *Buffer) checkEmpty(n int) bool { | ||||
| 	head := atomic.LoadInt64(&b.head) | ||||
| 	next := head + int64(n) | ||||
| 	tail := atomic.LoadInt64(&b.tail) | ||||
| 	if next-tail > int64(b.size) { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // checkFilled returns true if there are at least n bytes between the tail and | ||||
| // the head. | ||||
| func (b *Buffer) checkFilled(n int) bool { | ||||
| 	if atomic.LoadInt64(&b.tail)+int64(n) <= atomic.LoadInt64(&b.head) { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // CommitTail moves the tail position of the buffer n bytes. | ||||
| func (b *Buffer) CommitTail(n int) { | ||||
| 	atomic.AddInt64(&b.tail, int64(n)) | ||||
| 	b.rcond.L.Lock() | ||||
| 	b.rcond.Broadcast() | ||||
| 	b.rcond.L.Unlock() | ||||
| } | ||||
|  | ||||
| // CapDelta returns the difference between the head and tail. | ||||
| func (b *Buffer) CapDelta() int { | ||||
| 	return int(atomic.LoadInt64(&b.head) - atomic.LoadInt64(&b.tail)) | ||||
| } | ||||
|  | ||||
| // Stop signals the buffer to stop processing. | ||||
| func (b *Buffer) Stop() { | ||||
| 	atomic.StoreInt64(&b.done, 1) | ||||
| 	b.rcond.L.Lock() | ||||
| 	b.rcond.Broadcast() | ||||
| 	b.rcond.L.Unlock() | ||||
| 	b.wcond.L.Lock() | ||||
| 	b.wcond.Broadcast() | ||||
| 	b.wcond.L.Unlock() | ||||
| } | ||||
| @@ -1,304 +0,0 @@ | ||||
| package circ | ||||
|  | ||||
| import ( | ||||
| 	//"fmt" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewBuffer(t *testing.T) { | ||||
| 	var size int = 16 | ||||
| 	var block int = 4 | ||||
| 	buf := NewBuffer(size, block) | ||||
|  | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.NotNil(t, buf.rcond) | ||||
| 	require.NotNil(t, buf.wcond) | ||||
| 	require.Equal(t, size, len(buf.buf)) | ||||
| 	require.Equal(t, size, buf.size) | ||||
| 	require.Equal(t, block, buf.block) | ||||
| } | ||||
|  | ||||
| func TestNewBuffer0Size(t *testing.T) { | ||||
| 	buf := NewBuffer(0, 0) | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.Equal(t, DefaultBufferSize, buf.size) | ||||
| 	require.Equal(t, DefaultBlockSize, buf.block) | ||||
| } | ||||
|  | ||||
| func TestNewBufferUndersize(t *testing.T) { | ||||
| 	buf := NewBuffer(DefaultBlockSize+10, DefaultBlockSize) | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.Equal(t, DefaultBlockSize*2, buf.size) | ||||
| 	require.Equal(t, DefaultBlockSize, buf.block) | ||||
| } | ||||
|  | ||||
| func TestNewBufferFromSlice(t *testing.T) { | ||||
| 	b := NewBytesPool(256) | ||||
| 	buf := NewBufferFromSlice(DefaultBlockSize, b.Get()) | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.Equal(t, 256, cap(buf.buf)) | ||||
| } | ||||
|  | ||||
| func TestNewBufferFromSlice0Size(t *testing.T) { | ||||
| 	b := NewBytesPool(256) | ||||
| 	buf := NewBufferFromSlice(0, b.Get()) | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.Equal(t, 256, cap(buf.buf)) | ||||
| } | ||||
|  | ||||
| func TestGetPos(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	tail, head := buf.GetPos() | ||||
| 	require.Equal(t, int64(0), tail) | ||||
| 	require.Equal(t, int64(0), head) | ||||
|  | ||||
| 	atomic.StoreInt64(&buf.tail, 3) | ||||
| 	atomic.StoreInt64(&buf.head, 11) | ||||
|  | ||||
| 	tail, head = buf.GetPos() | ||||
| 	require.Equal(t, int64(3), tail) | ||||
| 	require.Equal(t, int64(11), head) | ||||
| } | ||||
|  | ||||
| func TestGet(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	require.Equal(t, make([]byte, 16), buf.Get()) | ||||
|  | ||||
| 	buf.buf[0] = 1 | ||||
| 	buf.buf[15] = 1 | ||||
| 	require.Equal(t, []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, buf.Get()) | ||||
| } | ||||
|  | ||||
| func TestSetPos(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	require.Equal(t, int64(0), atomic.LoadInt64(&buf.tail)) | ||||
| 	require.Equal(t, int64(0), atomic.LoadInt64(&buf.head)) | ||||
|  | ||||
| 	buf.SetPos(4, 8) | ||||
| 	require.Equal(t, int64(4), atomic.LoadInt64(&buf.tail)) | ||||
| 	require.Equal(t, int64(8), atomic.LoadInt64(&buf.head)) | ||||
| } | ||||
|  | ||||
| func TestSet(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	err := buf.Set([]byte{1, 1, 1, 1}, 17, 19) | ||||
| 	require.Error(t, err) | ||||
|  | ||||
| 	err = buf.Set([]byte{1, 1, 1, 1}, 4, 8) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, []byte{0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, buf.buf) | ||||
| } | ||||
|  | ||||
| func TestIndex(t *testing.T) { | ||||
| 	buf := NewBuffer(1024, 4) | ||||
| 	require.Equal(t, 512, buf.Index(512)) | ||||
| 	require.Equal(t, 0, buf.Index(1024)) | ||||
| 	require.Equal(t, 6, buf.Index(1030)) | ||||
| 	require.Equal(t, 6, buf.Index(61446)) | ||||
| } | ||||
|  | ||||
| func TestAwaitFilled(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		tail  int64 | ||||
| 		head  int64 | ||||
| 		n     int | ||||
| 		await int | ||||
| 		desc  string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 4, n: 4, await: 1, desc: "OK 0, 4"}, | ||||
| 		{tail: 8, head: 11, n: 4, await: 1, desc: "OK 8, 11"}, | ||||
| 		{tail: 102, head: 103, n: 4, await: 3, desc: "OK 102, 103"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		//fmt.Println(i) | ||||
| 		buf := NewBuffer(16, 4) | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
| 		o := make(chan error) | ||||
| 		go func() { | ||||
| 			o <- buf.awaitFilled(4) | ||||
| 		}() | ||||
|  | ||||
| 		time.Sleep(time.Millisecond) | ||||
| 		atomic.AddInt64(&buf.head, int64(tt.await)) | ||||
| 		buf.wcond.L.Lock() | ||||
| 		buf.wcond.Broadcast() | ||||
| 		buf.wcond.L.Unlock() | ||||
|  | ||||
| 		require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAwaitFilledEnded(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		o <- buf.awaitFilled(4) | ||||
| 	}() | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	atomic.StoreInt64(&buf.done, 1) | ||||
| 	buf.wcond.L.Lock() | ||||
| 	buf.wcond.Broadcast() | ||||
| 	buf.wcond.L.Unlock() | ||||
|  | ||||
| 	require.Error(t, <-o) | ||||
| } | ||||
|  | ||||
| func TestAwaitEmptyOK(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		tail  int64 | ||||
| 		head  int64 | ||||
| 		await int | ||||
| 		desc  string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 0, await: 0, desc: "OK 0, 0"}, | ||||
| 		{tail: 0, head: 5, await: 0, desc: "OK 0, 5"}, | ||||
| 		{tail: 0, head: 14, await: 3, desc: "OK wrap 0, 14 "}, | ||||
| 		{tail: 22, head: 35, await: 2, desc: "OK wrap 0, 14 "}, | ||||
| 		{tail: 15, head: 17, await: 7, desc: "OK 15,2"}, | ||||
| 		{tail: 0, head: 10, await: 2, desc: "OK 0, 10"}, | ||||
| 		{tail: 1, head: 15, await: 4, desc: "OK 2, 14"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		buf := NewBuffer(16, 4) | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
| 		o := make(chan error) | ||||
| 		go func() { | ||||
| 			o <- buf.awaitEmpty(4) | ||||
| 		}() | ||||
|  | ||||
| 		time.Sleep(time.Millisecond) | ||||
| 		atomic.AddInt64(&buf.tail, int64(tt.await)) | ||||
| 		buf.rcond.L.Lock() | ||||
| 		buf.rcond.Broadcast() | ||||
| 		buf.rcond.L.Unlock() | ||||
|  | ||||
| 		require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAwaitEmptyEnded(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	buf.SetPos(1, 15) | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		o <- buf.awaitEmpty(4) | ||||
| 	}() | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	atomic.StoreInt64(&buf.done, 1) | ||||
| 	buf.rcond.L.Lock() | ||||
| 	buf.rcond.Broadcast() | ||||
| 	buf.rcond.L.Unlock() | ||||
|  | ||||
| 	require.Error(t, <-o) | ||||
| } | ||||
|  | ||||
| func TestCheckEmpty(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		head int64 | ||||
| 		tail int64 | ||||
| 		want bool | ||||
| 		desc string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 0, want: true, desc: "0, 0 true"}, | ||||
| 		{tail: 3, head: 4, want: true, desc: "4, 3 true"}, | ||||
| 		{tail: 15, head: 17, want: true, desc: "15, 17(1) true"}, | ||||
| 		{tail: 1, head: 30, want: false, desc: "1, 30(14) false"}, | ||||
| 		{tail: 15, head: 30, want: false, desc: "15, 30(14) false; head has caught up to tail"}, | ||||
| 	} | ||||
| 	for i, tt := range tests { | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
| 		require.Equal(t, tt.want, buf.checkEmpty(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCheckFilled(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		head int64 | ||||
| 		tail int64 | ||||
| 		want bool | ||||
| 		desc string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 0, want: false, desc: "0, 0 false"}, | ||||
| 		{tail: 0, head: 4, want: true, desc: "0, 4 true"}, | ||||
| 		{tail: 14, head: 16, want: false, desc: "14,16 false"}, | ||||
| 		{tail: 14, head: 18, want: true, desc: "14,16 true"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
| 		require.Equal(t, tt.want, buf.checkFilled(4), "Mismatched bool wanted [i:%d] %s", i, tt.desc) | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| func TestCommitTail(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		tail  int64 | ||||
| 		head  int64 | ||||
| 		n     int | ||||
| 		next  int64 | ||||
| 		await int | ||||
| 		desc  string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 5, n: 4, next: 4, await: 0, desc: "OK 0, 4"}, | ||||
| 		{tail: 0, head: 5, n: 6, next: 6, await: 1, desc: "OK 0, 5"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		buf := NewBuffer(16, 4) | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
| 		go func() { | ||||
| 			buf.CommitTail(tt.n) | ||||
| 		}() | ||||
|  | ||||
| 		time.Sleep(time.Millisecond) | ||||
| 		for j := 0; j < tt.await; j++ { | ||||
| 			atomic.AddInt64(&buf.head, 1) | ||||
| 			buf.wcond.L.Lock() | ||||
| 			buf.wcond.Broadcast() | ||||
| 			buf.wcond.L.Unlock() | ||||
| 		} | ||||
| 		require.Equal(t, tt.next, atomic.LoadInt64(&buf.tail), "Next tail mismatch [i:%d] %s", i, tt.desc) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| /* | ||||
| func TestCommitTailEnded(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		o <- buf.CommitTail(5) | ||||
| 	}() | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	atomic.StoreInt64(&buf.done, 1) | ||||
| 	buf.wcond.L.Lock() | ||||
| 	buf.wcond.Broadcast() | ||||
| 	buf.wcond.L.Unlock() | ||||
|  | ||||
| 	require.Error(t, <-o) | ||||
| } | ||||
| */ | ||||
| func TestCapDelta(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
|  | ||||
| 	require.Equal(t, 0, buf.CapDelta()) | ||||
|  | ||||
| 	buf.SetPos(10, 15) | ||||
| 	require.Equal(t, 5, buf.CapDelta()) | ||||
| } | ||||
|  | ||||
| func TestStop(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	buf.Stop() | ||||
| 	require.Equal(t, int64(1), buf.done) | ||||
| } | ||||
| @@ -1,34 +0,0 @@ | ||||
| package circ | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
| ) | ||||
|  | ||||
| // BytesPool is a pool of []byte. | ||||
| type BytesPool struct { | ||||
| 	pool sync.Pool | ||||
| } | ||||
|  | ||||
| // NewBytesPool returns a sync.pool of []byte. | ||||
| func NewBytesPool(n int) BytesPool { | ||||
| 	return BytesPool{ | ||||
| 		pool: sync.Pool{ | ||||
| 			New: func() interface{} { | ||||
| 				return make([]byte, n) | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Get returns a pooled bytes.Buffer. | ||||
| func (b *BytesPool) Get() []byte { | ||||
| 	return b.pool.Get().([]byte) | ||||
| } | ||||
|  | ||||
| // Put puts the byte slice back into the pool. | ||||
| func (b *BytesPool) Put(x []byte) { | ||||
| 	for i := range x { | ||||
| 		x[i] = 0 | ||||
| 	} | ||||
| 	b.pool.Put(x) | ||||
| } | ||||
| @@ -1,46 +0,0 @@ | ||||
| package circ | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewBytesPool(t *testing.T) { | ||||
| 	bpool := NewBytesPool(256) | ||||
| 	require.NotNil(t, bpool.pool) | ||||
| } | ||||
|  | ||||
| func BenchmarkNewBytesPool(b *testing.B) { | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		NewBytesPool(256) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestNewBytesPoolGet(t *testing.T) { | ||||
| 	bpool := NewBytesPool(256) | ||||
| 	buf := bpool.Get() | ||||
|  | ||||
| 	require.Equal(t, make([]byte, 256), buf) | ||||
| } | ||||
|  | ||||
| func BenchmarkBytesPoolGet(b *testing.B) { | ||||
| 	bpool := NewBytesPool(256) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		bpool.Get() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestNewBytesPoolPut(t *testing.T) { | ||||
| 	bpool := NewBytesPool(256) | ||||
| 	buf := bpool.Get() | ||||
| 	bpool.Put(buf) | ||||
| } | ||||
|  | ||||
| func BenchmarkBytesPoolPut(b *testing.B) { | ||||
| 	bpool := NewBytesPool(256) | ||||
| 	buf := bpool.Get() | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		bpool.Put(buf) | ||||
| 	} | ||||
| } | ||||
| @@ -1,96 +0,0 @@ | ||||
| package circ | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| // Reader is a circular buffer for reading data from an io.Reader. | ||||
| type Reader struct { | ||||
| 	Buffer | ||||
| } | ||||
|  | ||||
| // NewReader returns a new Circular Reader. | ||||
| func NewReader(size, block int) *Reader { | ||||
| 	b := NewBuffer(size, block) | ||||
| 	b.ID = "\treader" | ||||
| 	return &Reader{ | ||||
| 		b, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // NewReaderFromSlice returns a new Circular Reader using a pre-exising | ||||
| // byte slice. | ||||
| func NewReaderFromSlice(block int, p []byte) *Reader { | ||||
| 	b := NewBufferFromSlice(block, p) | ||||
| 	b.ID = "\treader" | ||||
| 	return &Reader{ | ||||
| 		b, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadFrom reads bytes from an io.Reader and commits them to the buffer when | ||||
| // there is sufficient capacity to do so. | ||||
| func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) { | ||||
| 	atomic.StoreInt64(&b.State, 1) | ||||
| 	defer atomic.StoreInt64(&b.State, 0) | ||||
| 	for { | ||||
| 		if atomic.LoadInt64(&b.done) == 1 { | ||||
| 			return total, nil | ||||
| 		} | ||||
|  | ||||
| 		// Wait until there's enough capacity in the buffer before | ||||
| 		// trying to read more bytes from the io.Reader. | ||||
| 		err := b.awaitEmpty(b.block) | ||||
| 		if err != nil { | ||||
| 			// b.done is the only error condition for awaitCapacity | ||||
| 			// so loop around and return properly. | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// If the block will overrun the circle end, just fill up | ||||
| 		// and collect the rest on the next pass. | ||||
| 		start := b.Index(atomic.LoadInt64(&b.head)) | ||||
| 		end := start + b.block | ||||
| 		if end > b.size { | ||||
| 			end = b.size | ||||
| 		} | ||||
|  | ||||
| 		// Read into the buffer between the start and end indexes only. | ||||
| 		n, err := r.Read(b.buf[start:end]) | ||||
| 		total += int64(n) // incr total bytes read. | ||||
| 		if err != nil { | ||||
| 			return total, nil | ||||
| 		} | ||||
|  | ||||
| 		// Move the head forward however many bytes were read. | ||||
| 		atomic.AddInt64(&b.head, int64(n)) | ||||
|  | ||||
| 		b.wcond.L.Lock() | ||||
| 		b.wcond.Broadcast() | ||||
| 		b.wcond.L.Unlock() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Read reads n bytes from the buffer, and will block until at n bytes | ||||
| // exist in the buffer to read. | ||||
| func (b *Buffer) Read(n int) (p []byte, err error) { | ||||
| 	err = b.awaitFilled(n) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	tail := atomic.LoadInt64(&b.tail) | ||||
| 	next := tail + int64(n) | ||||
|  | ||||
| 	// If the read overruns the buffer, get everything until the end | ||||
| 	// and then whatever is left from the start. | ||||
| 	if b.Index(tail) > b.Index(next) { | ||||
| 		b.tmp = b.buf[b.Index(tail):] | ||||
| 		b.tmp = append(b.tmp, b.buf[:b.Index(next)]...) | ||||
| 	} else { | ||||
| 		b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read. | ||||
| 	} | ||||
|  | ||||
| 	return b.tmp, nil | ||||
| } | ||||
| @@ -1,125 +0,0 @@ | ||||
| package circ | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewReader(t *testing.T) { | ||||
| 	var size = 16 | ||||
| 	var block = 4 | ||||
| 	buf := NewReader(size, block) | ||||
|  | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.Equal(t, size, len(buf.buf)) | ||||
| 	require.Equal(t, size, buf.size) | ||||
| 	require.Equal(t, block, buf.block) | ||||
| } | ||||
|  | ||||
| func TestNewReaderFromSlice(t *testing.T) { | ||||
| 	b := NewBytesPool(256) | ||||
| 	buf := NewReaderFromSlice(DefaultBlockSize, b.Get()) | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.Equal(t, 256, cap(buf.buf)) | ||||
| } | ||||
|  | ||||
| func TestReadFrom(t *testing.T) { | ||||
| 	buf := NewReader(16, 4) | ||||
|  | ||||
| 	b4 := bytes.Repeat([]byte{'-'}, 4) | ||||
| 	br := bytes.NewReader(b4) | ||||
|  | ||||
| 	_, err := buf.ReadFrom(br) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buf[:4]) | ||||
| 	require.Equal(t, int64(4), buf.head) | ||||
|  | ||||
| 	br.Reset(b4) | ||||
| 	_, err = buf.ReadFrom(br) | ||||
| 	require.Equal(t, int64(8), buf.head) | ||||
|  | ||||
| 	br.Reset(b4) | ||||
| 	_, err = buf.ReadFrom(br) | ||||
| 	require.Equal(t, int64(12), buf.head) | ||||
| } | ||||
|  | ||||
| func TestReadFromWrap(t *testing.T) { | ||||
| 	buf := NewReader(16, 4) | ||||
| 	buf.buf = bytes.Repeat([]byte{'-'}, 16) | ||||
| 	buf.SetPos(8, 14) | ||||
| 	br := bytes.NewReader(bytes.Repeat([]byte{'/'}, 8)) | ||||
|  | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		_, err := buf.ReadFrom(br) | ||||
| 		o <- err | ||||
| 	}() | ||||
| 	time.Sleep(time.Millisecond * 100) | ||||
| 	go func() { | ||||
| 		atomic.StoreInt64(&buf.done, 1) | ||||
| 		buf.rcond.L.Lock() | ||||
| 		buf.rcond.Broadcast() | ||||
| 		buf.rcond.L.Unlock() | ||||
| 	}() | ||||
| 	<-o | ||||
| 	require.Equal(t, []byte{'/', '/', '/', '/', '/', '/', '-', '-', '-', '-', '-', '-', '-', '-', '/', '/'}, buf.Get()) | ||||
| 	require.Equal(t, int64(22), atomic.LoadInt64(&buf.head)) | ||||
| 	require.Equal(t, 6, buf.Index(atomic.LoadInt64(&buf.head))) | ||||
| } | ||||
|  | ||||
| func TestReadOK(t *testing.T) { | ||||
| 	buf := NewReader(16, 4) | ||||
| 	buf.buf = []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'} | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		tail  int64 | ||||
| 		head  int64 | ||||
| 		n     int | ||||
| 		bytes []byte | ||||
| 		desc  string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 4, n: 4, bytes: []byte{'a', 'b', 'c', 'd'}, desc: "0, 4 OK"}, | ||||
| 		{tail: 3, head: 15, n: 8, bytes: []byte{'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k'}, desc: "3, 15 OK"}, | ||||
| 		{tail: 14, head: 15, n: 6, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd'}, desc: "14, 2 wrapped OK"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
| 		o := make(chan []byte) | ||||
| 		go func() { | ||||
| 			p, _ := buf.Read(tt.n) | ||||
| 			o <- p | ||||
| 		}() | ||||
|  | ||||
| 		time.Sleep(time.Millisecond) | ||||
| 		atomic.StoreInt64(&buf.head, buf.head+int64(tt.n)) | ||||
|  | ||||
| 		buf.wcond.L.Lock() | ||||
| 		buf.wcond.Broadcast() | ||||
| 		buf.wcond.L.Unlock() | ||||
|  | ||||
| 		done := <-o | ||||
| 		require.Equal(t, tt.bytes, done, "Peeked bytes mismatch [i:%d] %s", i, tt.desc) | ||||
|  | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestReadEnded(t *testing.T) { | ||||
| 	buf := NewBuffer(16, 4) | ||||
| 	o := make(chan error) | ||||
| 	go func() { | ||||
| 		_, err := buf.Read(4) | ||||
| 		o <- err | ||||
| 	}() | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	atomic.StoreInt64(&buf.done, 1) | ||||
| 	buf.wcond.L.Lock() | ||||
| 	buf.wcond.Broadcast() | ||||
| 	buf.wcond.L.Unlock() | ||||
|  | ||||
| 	require.Error(t, <-o) | ||||
| } | ||||
| @@ -1,108 +0,0 @@ | ||||
| package circ | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| // Writer is a circular buffer for writing data to an io.Writer. | ||||
| type Writer struct { | ||||
| 	Buffer | ||||
| } | ||||
|  | ||||
| // NewWriter returns a pointer to a new Circular Writer. | ||||
| func NewWriter(size, block int) *Writer { | ||||
| 	b := NewBuffer(size, block) | ||||
| 	b.ID = "writer" | ||||
| 	return &Writer{ | ||||
| 		b, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // NewWriterFromSlice returns a new Circular Writer using a pre-exising | ||||
| // byte slice. | ||||
| func NewWriterFromSlice(block int, p []byte) *Writer { | ||||
| 	b := NewBufferFromSlice(block, p) | ||||
| 	b.ID = "writer" | ||||
| 	return &Writer{ | ||||
| 		b, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WriteTo writes the contents of the buffer to an io.Writer. | ||||
| func (b *Writer) WriteTo(w io.Writer) (total int, err error) { | ||||
| 	atomic.StoreInt64(&b.State, 2) | ||||
| 	defer atomic.StoreInt64(&b.State, 0) | ||||
| 	for { | ||||
| 		if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 { | ||||
| 			return total, io.EOF | ||||
| 		} | ||||
|  | ||||
| 		// Read from the buffer until there is at least 1 byte to write. | ||||
| 		err = b.awaitFilled(1) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Get all the bytes between the tail and head, wrapping if necessary. | ||||
| 		tail := atomic.LoadInt64(&b.tail) | ||||
| 		rTail := b.Index(tail) | ||||
| 		rHead := b.Index(atomic.LoadInt64(&b.head)) | ||||
| 		n := b.CapDelta() | ||||
| 		p := make([]byte, 0, n) | ||||
|  | ||||
| 		if rTail > rHead { | ||||
| 			p = append(p, b.buf[rTail:]...) | ||||
| 			p = append(p, b.buf[:rHead]...) | ||||
| 		} else { | ||||
| 			p = append(p, b.buf[rTail:rHead]...) | ||||
| 		} | ||||
|  | ||||
| 		//fmt.Println("writing", p) | ||||
| 		n, err = w.Write(p) | ||||
| 		total += n | ||||
| 		if err != nil { | ||||
| 			fmt.Println("writing err", err) | ||||
| 			return | ||||
| 		} | ||||
| 		//fmt.Println("written", n) | ||||
|  | ||||
| 		// Move the tail forward the bytes written and broadcast change. | ||||
| 		atomic.StoreInt64(&b.tail, tail+int64(n)) | ||||
| 		b.rcond.L.Lock() | ||||
| 		b.rcond.Broadcast() | ||||
| 		b.rcond.L.Unlock() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Write writes the buffer to the buffer p, returning the number of bytes written. | ||||
| // The bytes written to the buffer are picked up by WriteTo. | ||||
| func (b *Writer) Write(p []byte) (total int, err error) { | ||||
| 	err = b.awaitEmpty(len(p)) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	total = b.writeBytes(p) | ||||
| 	atomic.AddInt64(&b.head, int64(total)) | ||||
| 	b.wcond.L.Lock() | ||||
| 	b.wcond.Broadcast() | ||||
| 	b.wcond.L.Unlock() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // writeBytes writes bytes to the buffer from the start position, and returns | ||||
| // the new head position. This function does not wait for capacity and will | ||||
| // overwrite any existing bytes. | ||||
| func (b *Writer) writeBytes(p []byte) int { | ||||
| 	var o int | ||||
| 	var n int | ||||
| 	for i := 0; i < len(p); i++ { | ||||
| 		o = b.Index(atomic.LoadInt64(&b.head) + int64(i)) | ||||
| 		b.buf[o] = p[i] | ||||
| 		n++ | ||||
| 	} | ||||
|  | ||||
| 	return n | ||||
| } | ||||
| @@ -1,155 +0,0 @@ | ||||
| package circ | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"net" | ||||
| 	"sync/atomic" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewWriter(t *testing.T) { | ||||
| 	var size = 16 | ||||
| 	var block = 4 | ||||
| 	buf := NewWriter(size, block) | ||||
|  | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.Equal(t, size, len(buf.buf)) | ||||
| 	require.Equal(t, size, buf.size) | ||||
| 	require.Equal(t, block, buf.block) | ||||
| } | ||||
|  | ||||
| func TestNewWriterFromSlice(t *testing.T) { | ||||
| 	b := NewBytesPool(256) | ||||
| 	buf := NewWriterFromSlice(DefaultBlockSize, b.Get()) | ||||
| 	require.NotNil(t, buf.buf) | ||||
| 	require.Equal(t, 256, cap(buf.buf)) | ||||
| } | ||||
|  | ||||
| func TestWriteTo(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		tail  int64 | ||||
| 		head  int64 | ||||
| 		bytes []byte | ||||
| 		await int | ||||
| 		total int | ||||
| 		err   error | ||||
| 		desc  string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 5, bytes: []byte{'a', 'b', 'c', 'd', 'e'}, desc: "0,5 OK"}, | ||||
| 		{tail: 14, head: 21, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd', 'e'}, desc: "14,16(2) OK"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'} | ||||
| 		buf := NewWriter(16, 4) | ||||
| 		buf.Set(bb, 0, 16) | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
|  | ||||
| 		var b bytes.Buffer | ||||
| 		w := bufio.NewWriter(&b) | ||||
|  | ||||
| 		nc := make(chan int) | ||||
| 		go func() { | ||||
| 			n, _ := buf.WriteTo(w) | ||||
| 			nc <- n | ||||
| 		}() | ||||
|  | ||||
| 		time.Sleep(time.Millisecond * 100) | ||||
| 		atomic.StoreInt64(&buf.done, 1) | ||||
| 		buf.wcond.L.Lock() | ||||
| 		buf.wcond.Broadcast() | ||||
| 		buf.wcond.L.Unlock() | ||||
|  | ||||
| 		w.Flush() | ||||
| 		require.Equal(t, tt.bytes, b.Bytes(), "Written bytes mismatch [i:%d] %s", i, tt.desc) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestWriteToEndedFirst(t *testing.T) { | ||||
| 	buf := NewWriter(16, 4) | ||||
| 	buf.done = 1 | ||||
|  | ||||
| 	var b bytes.Buffer | ||||
| 	w := bufio.NewWriter(&b) | ||||
| 	_, err := buf.WriteTo(w) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestWriteToBadWriter(t *testing.T) { | ||||
| 	bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'} | ||||
| 	buf := NewWriter(16, 4) | ||||
| 	buf.Set(bb, 0, 16) | ||||
| 	buf.SetPos(0, 6) | ||||
| 	r, w := net.Pipe() | ||||
|  | ||||
| 	w.Close() | ||||
| 	_, err := buf.WriteTo(w) | ||||
| 	require.Error(t, err) | ||||
| 	r.Close() | ||||
| } | ||||
|  | ||||
| func TestWrite(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		tail  int64 | ||||
| 		head  int64 | ||||
| 		rHead int64 | ||||
| 		bytes []byte | ||||
| 		want  []byte | ||||
| 		desc  string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 0, rHead: 4, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, desc: "0>4 OK"}, | ||||
| 		{tail: 4, head: 14, rHead: 2, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 'a', 'b'}, desc: "14>2 OK"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		buf := NewWriter(16, 4) | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
|  | ||||
| 		o := make(chan []interface{}) | ||||
| 		go func() { | ||||
| 			nn, err := buf.Write(tt.bytes) | ||||
| 			o <- []interface{}{nn, err} | ||||
| 		}() | ||||
|  | ||||
| 		done := <-o | ||||
| 		require.Equal(t, tt.want, buf.buf, "Wanted written mismatch [i:%d] %s", i, tt.desc) | ||||
| 		require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestWriteEnded(t *testing.T) { | ||||
| 	buf := NewWriter(16, 4) | ||||
| 	buf.SetPos(15, 30) | ||||
| 	buf.done = 1 | ||||
|  | ||||
| 	_, err := buf.Write([]byte{'a', 'b', 'c', 'd'}) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestWriteBytes(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		tail  int64 | ||||
| 		head  int64 | ||||
| 		bytes []byte | ||||
| 		want  []byte | ||||
| 		start int | ||||
| 		desc  string | ||||
| 	}{ | ||||
| 		{tail: 0, head: 0, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'a', 'b', 'c', 'd', 0, 0, 0, 0}, desc: "0,4 OK"}, | ||||
| 		{tail: 6, head: 6, bytes: []byte{'a', 'b', 'c', 'd'}, want: []byte{'c', 'd', 0, 0, 0, 0, 'a', 'b'}, desc: "6,2 OK wrapped"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, tt := range tests { | ||||
| 		buf := NewWriter(8, 4) | ||||
| 		buf.SetPos(tt.tail, tt.head) | ||||
| 		n := buf.writeBytes(tt.bytes) | ||||
|  | ||||
| 		require.Equal(t, tt.want, buf.buf, "Buffer mistmatch [i:%d] %s", i, tt.desc) | ||||
| 		require.Equal(t, len(tt.bytes), n) | ||||
| 	} | ||||
|  | ||||
| } | ||||
| @@ -1,516 +0,0 @@ | ||||
| package clients | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/binary" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/rs/xid" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/server/internal/circ" | ||||
| 	"github.com/mochi-co/mqtt/server/internal/packets" | ||||
| 	"github.com/mochi-co/mqtt/server/internal/topics" | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| 	"github.com/mochi-co/mqtt/server/system" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// defaultKeepalive is the default connection keepalive value in seconds. | ||||
| 	defaultKeepalive uint16 = 10 | ||||
|  | ||||
| 	ErrConnectionClosed = errors.New("Connection not open") | ||||
| ) | ||||
|  | ||||
| // Clients contains a map of the clients known by the broker. | ||||
| type Clients struct { | ||||
| 	sync.RWMutex | ||||
| 	internal map[string]*Client // clients known by the broker, keyed on client id. | ||||
| } | ||||
|  | ||||
| // New returns an instance of Clients. | ||||
| func New() *Clients { | ||||
| 	return &Clients{ | ||||
| 		internal: make(map[string]*Client), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Add adds a new client to the clients map, keyed on client id. | ||||
| func (cl *Clients) Add(val *Client) { | ||||
| 	cl.Lock() | ||||
| 	cl.internal[val.ID] = val | ||||
| 	cl.Unlock() | ||||
| } | ||||
|  | ||||
| // Get returns the value of a client if it exists. | ||||
| func (cl *Clients) Get(id string) (*Client, bool) { | ||||
| 	cl.RLock() | ||||
| 	val, ok := cl.internal[id] | ||||
| 	cl.RUnlock() | ||||
| 	return val, ok | ||||
| } | ||||
|  | ||||
| // Len returns the length of the clients map. | ||||
| func (cl *Clients) Len() int { | ||||
| 	cl.RLock() | ||||
| 	val := len(cl.internal) | ||||
| 	cl.RUnlock() | ||||
| 	return val | ||||
| } | ||||
|  | ||||
| // Delete removes a client from the internal map. | ||||
| func (cl *Clients) Delete(id string) { | ||||
| 	cl.Lock() | ||||
| 	delete(cl.internal, id) | ||||
| 	cl.Unlock() | ||||
| } | ||||
|  | ||||
| // GetByListener returns clients matching a listener id. | ||||
| func (cl *Clients) GetByListener(id string) []*Client { | ||||
| 	clients := make([]*Client, 0, cl.Len()) | ||||
| 	cl.RLock() | ||||
| 	for _, v := range cl.internal { | ||||
| 		if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 { | ||||
| 			clients = append(clients, v) | ||||
| 		} | ||||
| 	} | ||||
| 	cl.RUnlock() | ||||
| 	return clients | ||||
| } | ||||
|  | ||||
| // Client contains information about a client known by the broker. | ||||
| type Client struct { | ||||
| 	sync.RWMutex | ||||
| 	conn          net.Conn             // the net.Conn used to establish the connection. | ||||
| 	r             *circ.Reader         // a reader for reading incoming bytes. | ||||
| 	w             *circ.Writer         // a writer for writing outgoing bytes. | ||||
| 	ID            string               // the client id. | ||||
| 	AC            auth.Controller      // an auth controller inherited from the listener. | ||||
| 	Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains. | ||||
| 	Listener      string               // the id of the listener the client is connected to. | ||||
| 	Inflight      Inflight             // a map of in-flight qos messages. | ||||
| 	Username      []byte               // the username the client authenticated with. | ||||
| 	keepalive     uint16               // the number of seconds the connection can wait. | ||||
| 	cleanSession  bool                 // indicates if the client expects a clean-session. | ||||
| 	packetID      uint32               // the current highest packetID. | ||||
| 	LWT           LWT                  // the last will and testament for the client. | ||||
| 	State         State                // the operational state of the client. | ||||
| 	system        *system.Info         // pointers to server system info. | ||||
| } | ||||
|  | ||||
| // State tracks the state of the client. | ||||
| type State struct { | ||||
| 	Done    int64           // atomic counter which indicates that the client has closed. | ||||
| 	started *sync.WaitGroup // tracks the goroutines which have been started. | ||||
| 	endedW  *sync.WaitGroup // tracks when the writer has ended. | ||||
| 	endedR  *sync.WaitGroup // tracks when the reader has ended. | ||||
| 	endOnce sync.Once       // only end once. | ||||
| } | ||||
|  | ||||
| // NewClient returns a new instance of Client. | ||||
| func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client { | ||||
| 	cl := &Client{ | ||||
| 		conn:      c, | ||||
| 		r:         r, | ||||
| 		w:         w, | ||||
| 		system:    s, | ||||
| 		keepalive: defaultKeepalive, | ||||
| 		Inflight: Inflight{ | ||||
| 			internal: make(map[uint16]InflightMessage), | ||||
| 		}, | ||||
| 		Subscriptions: make(map[string]byte), | ||||
| 		State: State{ | ||||
| 			started: new(sync.WaitGroup), | ||||
| 			endedW:  new(sync.WaitGroup), | ||||
| 			endedR:  new(sync.WaitGroup), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	cl.refreshDeadline(cl.keepalive) | ||||
|  | ||||
| 	return cl | ||||
| } | ||||
|  | ||||
| // NewClientStub returns an instance of Client with basic initializations. This | ||||
| // method is typically called by the persistence restoration system. | ||||
| func NewClientStub(s *system.Info) *Client { | ||||
| 	return &Client{ | ||||
| 		Inflight: Inflight{ | ||||
| 			internal: make(map[uint16]InflightMessage), | ||||
| 		}, | ||||
| 		Subscriptions: make(map[string]byte), | ||||
| 		State: State{ | ||||
| 			Done: 1, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Identify sets the identification values of a client instance. | ||||
| func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) { | ||||
| 	cl.Listener = lid | ||||
| 	cl.AC = ac | ||||
|  | ||||
| 	cl.ID = pk.ClientIdentifier | ||||
| 	if cl.ID == "" { | ||||
| 		cl.ID = xid.New().String() | ||||
| 	} | ||||
|  | ||||
| 	cl.r.ID = cl.ID + " READER" | ||||
| 	cl.w.ID = cl.ID + " WRITER" | ||||
|  | ||||
| 	cl.Username = pk.Username | ||||
| 	cl.cleanSession = pk.CleanSession | ||||
| 	cl.keepalive = pk.Keepalive | ||||
|  | ||||
| 	if pk.WillFlag { | ||||
| 		cl.LWT = LWT{ | ||||
| 			Topic:   pk.WillTopic, | ||||
| 			Message: pk.WillMessage, | ||||
| 			Qos:     pk.WillQos, | ||||
| 			Retain:  pk.WillRetain, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	cl.refreshDeadline(cl.keepalive) | ||||
| } | ||||
|  | ||||
| // refreshDeadline refreshes the read/write deadline for the net.Conn connection. | ||||
| func (cl *Client) refreshDeadline(keepalive uint16) { | ||||
| 	if cl.conn != nil { | ||||
| 		var expiry time.Time // Nil time can be used to disable deadline if keepalive = 0 | ||||
| 		if keepalive > 0 { | ||||
| 			expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) | ||||
| 		} | ||||
| 		cl.conn.SetDeadline(expiry) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // NextPacketID returns the next packet id for a client, looping back to 0 | ||||
| // if the maximum ID has been reached. | ||||
| func (cl *Client) NextPacketID() uint32 { | ||||
| 	i := atomic.LoadUint32(&cl.packetID) | ||||
| 	if i == uint32(65535) || i == uint32(0) { | ||||
| 		atomic.StoreUint32(&cl.packetID, 1) | ||||
| 		return 1 | ||||
| 	} | ||||
|  | ||||
| 	return atomic.AddUint32(&cl.packetID, 1) | ||||
| } | ||||
|  | ||||
| // NoteSubscription makes a note of a subscription for the client. | ||||
| func (cl *Client) NoteSubscription(filter string, qos byte) { | ||||
| 	cl.Lock() | ||||
| 	cl.Subscriptions[filter] = qos | ||||
| 	cl.Unlock() | ||||
| } | ||||
|  | ||||
| // ForgetSubscription forgests a subscription note for the client. | ||||
| func (cl *Client) ForgetSubscription(filter string) { | ||||
| 	cl.Lock() | ||||
| 	delete(cl.Subscriptions, filter) | ||||
| 	cl.Unlock() | ||||
| } | ||||
|  | ||||
| // Start begins the client goroutines reading and writing packets. | ||||
| func (cl *Client) Start() { | ||||
| 	cl.State.started.Add(2) | ||||
|  | ||||
| 	go func() { | ||||
| 		cl.State.started.Done() | ||||
| 		cl.w.WriteTo(cl.conn) | ||||
| 		cl.State.endedW.Done() | ||||
| 		cl.Stop() | ||||
| 	}() | ||||
| 	cl.State.endedW.Add(1) | ||||
|  | ||||
| 	go func() { | ||||
| 		cl.State.started.Done() | ||||
| 		cl.r.ReadFrom(cl.conn) | ||||
| 		cl.State.endedR.Done() | ||||
| 		cl.Stop() | ||||
| 	}() | ||||
| 	cl.State.endedR.Add(1) | ||||
|  | ||||
| 	cl.State.started.Wait() | ||||
| } | ||||
|  | ||||
| // Stop instructs the client to shut down all processing goroutines and disconnect. | ||||
| func (cl *Client) Stop() { | ||||
| 	if atomic.LoadInt64(&cl.State.Done) == 1 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	cl.State.endOnce.Do(func() { | ||||
| 		cl.r.Stop() | ||||
| 		cl.w.Stop() | ||||
| 		cl.State.endedW.Wait() | ||||
|  | ||||
| 		cl.conn.Close() | ||||
|  | ||||
| 		cl.State.endedR.Wait() | ||||
| 		atomic.StoreInt64(&cl.State.Done, 1) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // readFixedHeader reads in the values of the next packet's fixed header. | ||||
| func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error { | ||||
| 	p, err := cl.r.Read(1) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = fh.Decode(p[0]) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// The remaining length value can be up to 5 bytes. Read through each byte | ||||
| 	// looking for continue values, and if found increase the read. Otherwise | ||||
| 	// decode the bytes that were legit. | ||||
| 	buf := make([]byte, 0, 6) | ||||
| 	i := 1 | ||||
| 	n := 2 | ||||
| 	for ; n < 6; n++ { | ||||
| 		p, err = cl.r.Read(n) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		buf = append(buf, p[i]) | ||||
|  | ||||
| 		// If it's not a continuation flag, end here. | ||||
| 		if p[i] < 128 { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		// If i has reached 4 without a length terminator, return a protocol violation. | ||||
| 		i++ | ||||
| 		if i == 4 { | ||||
| 			return packets.ErrOversizedLengthIndicator | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Calculate and store the remaining length of the packet payload. | ||||
| 	rem, _ := binary.Uvarint(buf) | ||||
| 	fh.Remaining = int(rem) | ||||
|  | ||||
| 	// Having successfully read n bytes, commit the tail forward. | ||||
| 	cl.r.CommitTail(n) | ||||
| 	atomic.AddInt64(&cl.system.BytesRecv, int64(n)) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Read loops forever reading new packets from a client connection until | ||||
| // an error is encountered (or the connection is closed). | ||||
| func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error { | ||||
| 	for { | ||||
| 		if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 { | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		cl.refreshDeadline(cl.keepalive) | ||||
| 		fh := new(packets.FixedHeader) | ||||
| 		err := cl.ReadFixedHeader(fh) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		pk, err := cl.ReadPacket(fh) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = packetHandler(cl, pk) // Process inbound packet. | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadPacket reads the remaining buffer into an MQTT packet. | ||||
| func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) { | ||||
| 	atomic.AddInt64(&cl.system.MessagesRecv, 1) | ||||
|  | ||||
| 	pk.FixedHeader = *fh | ||||
| 	if pk.FixedHeader.Remaining == 0 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	p, err := cl.r.Read(pk.FixedHeader.Remaining) | ||||
| 	if err != nil { | ||||
| 		return pk, err | ||||
| 	} | ||||
| 	atomic.AddInt64(&cl.system.BytesRecv, int64(len(p))) | ||||
|  | ||||
| 	// Decode the remaining packet values using a fresh copy of the bytes, | ||||
| 	// otherwise the next packet will change the data of this one. | ||||
| 	px := append([]byte{}, p[:]...) | ||||
|  | ||||
| 	switch pk.FixedHeader.Type { | ||||
| 	case packets.Connect: | ||||
| 		err = pk.ConnectDecode(px) | ||||
| 	case packets.Connack: | ||||
| 		err = pk.ConnackDecode(px) | ||||
| 	case packets.Publish: | ||||
| 		err = pk.PublishDecode(px) | ||||
| 		if err == nil { | ||||
| 			atomic.AddInt64(&cl.system.PublishRecv, 1) | ||||
| 		} | ||||
| 	case packets.Puback: | ||||
| 		err = pk.PubackDecode(px) | ||||
| 	case packets.Pubrec: | ||||
| 		err = pk.PubrecDecode(px) | ||||
| 	case packets.Pubrel: | ||||
| 		err = pk.PubrelDecode(px) | ||||
| 	case packets.Pubcomp: | ||||
| 		err = pk.PubcompDecode(px) | ||||
| 	case packets.Subscribe: | ||||
| 		err = pk.SubscribeDecode(px) | ||||
| 	case packets.Suback: | ||||
| 		err = pk.SubackDecode(px) | ||||
| 	case packets.Unsubscribe: | ||||
| 		err = pk.UnsubscribeDecode(px) | ||||
| 	case packets.Unsuback: | ||||
| 		err = pk.UnsubackDecode(px) | ||||
| 	case packets.Pingreq: | ||||
| 	case packets.Pingresp: | ||||
| 	case packets.Disconnect: | ||||
| 	default: | ||||
| 		err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type) | ||||
| 	} | ||||
|  | ||||
| 	cl.r.CommitTail(pk.FixedHeader.Remaining) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // WritePacket encodes and writes a packet to the client. | ||||
| func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) { | ||||
| 	if atomic.LoadInt64(&cl.State.Done) == 1 { | ||||
| 		return 0, ErrConnectionClosed | ||||
| 	} | ||||
|  | ||||
| 	cl.w.Mu.Lock() | ||||
| 	defer cl.w.Mu.Unlock() | ||||
|  | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	switch pk.FixedHeader.Type { | ||||
| 	case packets.Connect: | ||||
| 		err = pk.ConnectEncode(buf) | ||||
| 	case packets.Connack: | ||||
| 		err = pk.ConnackEncode(buf) | ||||
| 	case packets.Publish: | ||||
| 		err = pk.PublishEncode(buf) | ||||
| 		if err == nil { | ||||
| 			atomic.AddInt64(&cl.system.PublishSent, 1) | ||||
| 		} | ||||
| 	case packets.Puback: | ||||
| 		err = pk.PubackEncode(buf) | ||||
| 	case packets.Pubrec: | ||||
| 		err = pk.PubrecEncode(buf) | ||||
| 	case packets.Pubrel: | ||||
| 		err = pk.PubrelEncode(buf) | ||||
| 	case packets.Pubcomp: | ||||
| 		err = pk.PubcompEncode(buf) | ||||
| 	case packets.Subscribe: | ||||
| 		err = pk.SubscribeEncode(buf) | ||||
| 	case packets.Suback: | ||||
| 		err = pk.SubackEncode(buf) | ||||
| 	case packets.Unsubscribe: | ||||
| 		err = pk.UnsubscribeEncode(buf) | ||||
| 	case packets.Unsuback: | ||||
| 		err = pk.UnsubackEncode(buf) | ||||
| 	case packets.Pingreq: | ||||
| 		err = pk.PingreqEncode(buf) | ||||
| 	case packets.Pingresp: | ||||
| 		err = pk.PingrespEncode(buf) | ||||
| 	case packets.Disconnect: | ||||
| 		err = pk.DisconnectEncode(buf) | ||||
| 	default: | ||||
| 		err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Write the packet bytes to the client byte buffer. | ||||
| 	n, err = cl.w.Write(buf.Bytes()) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	atomic.AddInt64(&cl.system.BytesSent, int64(n)) | ||||
| 	atomic.AddInt64(&cl.system.MessagesSent, 1) | ||||
|  | ||||
| 	cl.refreshDeadline(cl.keepalive) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // LWT contains the last will and testament details for a client connection. | ||||
| type LWT struct { | ||||
| 	Topic   string // the topic the will message shall be sent to. | ||||
| 	Message []byte // the message that shall be sent when the client disconnects. | ||||
| 	Qos     byte   // the quality of service desired. | ||||
| 	Retain  bool   // indicates whether the will message should be retained | ||||
| } | ||||
|  | ||||
| // InflightMessage contains data about a packet which is currently in-flight. | ||||
| type InflightMessage struct { | ||||
| 	Packet  packets.Packet // the packet currently in-flight. | ||||
| 	Sent    int64          // the last time the message was sent (for retries) in unixtime. | ||||
| 	Resends int            // the number of times the message was attempted to be sent. | ||||
| } | ||||
|  | ||||
| // Inflight is a map of InflightMessage keyed on packet id. | ||||
| type Inflight struct { | ||||
| 	sync.RWMutex | ||||
| 	internal map[uint16]InflightMessage // internal contains the inflight messages. | ||||
| } | ||||
|  | ||||
| // Set stores the packet of an Inflight message, keyed on message id. Returns | ||||
| // true if the inflight message was new. | ||||
| func (i *Inflight) Set(key uint16, in InflightMessage) bool { | ||||
| 	i.Lock() | ||||
| 	_, ok := i.internal[key] | ||||
| 	i.internal[key] = in | ||||
| 	i.Unlock() | ||||
| 	return !ok | ||||
| } | ||||
|  | ||||
| // Get returns the value of an in-flight message if it exists. | ||||
| func (i *Inflight) Get(key uint16) (InflightMessage, bool) { | ||||
| 	i.RLock() | ||||
| 	val, ok := i.internal[key] | ||||
| 	i.RUnlock() | ||||
| 	return val, ok | ||||
| } | ||||
|  | ||||
| // Len returns the size of the in-flight messages map. | ||||
| func (i *Inflight) Len() int { | ||||
| 	i.RLock() | ||||
| 	v := len(i.internal) | ||||
| 	i.RUnlock() | ||||
| 	return v | ||||
| } | ||||
|  | ||||
| // GetAll returns all the in-flight messages. | ||||
| func (i *Inflight) GetAll() map[uint16]InflightMessage { | ||||
| 	i.RLock() | ||||
| 	defer i.RUnlock() | ||||
| 	return i.internal | ||||
| } | ||||
|  | ||||
| // Delete removes an in-flight message from the map. Returns true if the | ||||
| // message existed. | ||||
| func (i *Inflight) Delete(key uint16) bool { | ||||
| 	i.Lock() | ||||
| 	_, ok := i.internal[key] | ||||
| 	delete(i.internal, key) | ||||
| 	i.Unlock() | ||||
| 	return ok | ||||
| } | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -1,383 +0,0 @@ | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestBytesToString(t *testing.T) { | ||||
| 	b := []byte{'a', 'b', 'c'} | ||||
| 	require.Equal(t, "abc", bytesToString(b)) | ||||
| } | ||||
|  | ||||
| func BenchmarkBytesToString(b *testing.B) { | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		bytesToString([]byte{'a', 'b', 'c'}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeString(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     []string | ||||
| 		offset     int | ||||
| 		shouldFail bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			offset:   0, | ||||
| 			rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, | ||||
| 			result:   []string{"a/b/c/d", "a"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset: 14, | ||||
| 			rawBytes: []byte{ | ||||
| 				byte(Connect << 4), 17, // Fixed header | ||||
| 				0, 6, // Protocol Name - MSB+LSB | ||||
| 				'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name | ||||
| 				3,     // Protocol Version | ||||
| 				0,     // Packet Flags | ||||
| 				0, 30, // Keepalive | ||||
| 				0, 3, // Client ID - MSB+LSB | ||||
| 				'h', 'e', 'y', // Client ID "zen"}, | ||||
| 			}, | ||||
| 			result: []string{"hey"}, | ||||
| 		}, | ||||
|  | ||||
| 		{ | ||||
| 			offset:   2, | ||||
| 			rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97}, | ||||
| 			result:   []string{"1/2/3/4/a/b/c/d/e/^/@/!", "a"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:   0, | ||||
| 			rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38}, | ||||
| 			result:   []string{"x/y/z", "!@#$%^&"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:     0, | ||||
| 			rawBytes:   []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'}, | ||||
| 			result:     []string{"a/b/c/d", "z"}, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:     5, | ||||
| 			rawBytes:   []byte{0, 7, 97, 47, 98, 47, 'x'}, | ||||
| 			result:     []string{"a/b/c/d", "x"}, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset:     9, | ||||
| 			rawBytes:   []byte{0, 7, 97, 47, 98, 47, 'y'}, | ||||
| 			result:     []string{"a/b/c/d", "y"}, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			offset: 17, | ||||
| 			rawBytes: []byte{ | ||||
| 				byte(Connect << 4), 0, // Fixed header | ||||
| 				0, 4, // Protocol Name - MSB+LSB | ||||
| 				'M', 'Q', 'T', 'T', // Protocol Name | ||||
| 				4,     // Protocol Version | ||||
| 				0,     // Flags | ||||
| 				0, 20, // Keepalive | ||||
| 				0, 3, // Client ID - MSB+LSB | ||||
| 				'z', 'e', 'n', // Client ID "zen" | ||||
| 				0, 6, // Will Topic - MSB+LSB | ||||
| 				'l', | ||||
| 			}, | ||||
| 			result:     []string{"lwt"}, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		result, _, err := decodeString(wanted.rawBytes, wanted.offset) | ||||
| 		if wanted.shouldFail { | ||||
| 			require.Error(t, err, "Expected error decoding string [i:%d]", i) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		require.NoError(t, err, "Error decoding string [i:%d]", i) | ||||
| 		require.Equal(t, wanted.result[0], result, "Incorrect decoded value [i:%d]", i) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkDecodeString(b *testing.B) { | ||||
| 	in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97} | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		decodeString(in, 0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeBytes(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     []uint8 | ||||
| 		next       int | ||||
| 		offset     int | ||||
| 		shouldFail bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session) | ||||
| 			result:   []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), | ||||
| 			next:     6, | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // ... truncated connect packet, only checking start | ||||
| 			result:   []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), | ||||
| 			next:     6, | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 4, 77, 81}, | ||||
| 			result:     []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), | ||||
| 			offset:     0, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 4, 77, 81}, | ||||
| 			result:     []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), | ||||
| 			offset:     8, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 4, 77, 81}, | ||||
| 			result:     []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), | ||||
| 			offset:     0, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		result, _, err := decodeBytes(wanted.rawBytes, wanted.offset) | ||||
| 		if wanted.shouldFail { | ||||
| 			require.Error(t, err, "Expected error decoding bytes [i:%d]", i) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		require.NoError(t, err, "Error decoding bytes [i:%d]", i) | ||||
| 		require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkDecodeBytes(b *testing.B) { | ||||
| 	in := []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52} | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		decodeBytes(in, 0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeByte(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     uint8 | ||||
| 		offset     int | ||||
| 		shouldFail bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes | ||||
| 			result:   uint8(0x00), | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84}, | ||||
| 			result:   uint8(0x04), | ||||
| 			offset:   1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84}, | ||||
| 			result:   uint8(0x4d), | ||||
| 			offset:   2, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 4, 77, 81, 84, 84}, | ||||
| 			result:   uint8(0x51), | ||||
| 			offset:   3, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 4, 77, 80, 82, 84}, | ||||
| 			result:     uint8(0x00), | ||||
| 			offset:     8, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		result, offset, err := decodeByte(wanted.rawBytes, wanted.offset) | ||||
| 		if wanted.shouldFail { | ||||
| 			require.Error(t, err, "Expected error decoding byte [i:%d]", i) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		require.NoError(t, err, "Error decoding byte [i:%d]", i) | ||||
| 		require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) | ||||
| 		require.Equal(t, i+1, offset, "Incorrect offset value [i:%d]", i) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkDecodeByte(b *testing.B) { | ||||
| 	in := []byte{0, 4, 77, 81, 84, 84} | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		decodeByte(in, 0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeUint16(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     uint16 | ||||
| 		offset     int | ||||
| 		shouldFail bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, | ||||
| 			result:   uint16(0x07), | ||||
| 			offset:   0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, | ||||
| 			result:   uint16(0x761), | ||||
| 			offset:   1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0, 7, 255, 47}, | ||||
| 			result:     uint16(0x761), | ||||
| 			offset:     8, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset) | ||||
| 		if wanted.shouldFail { | ||||
| 			require.Error(t, err, "Expected error decoding uint16 [i:%d]", i) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		require.NoError(t, err, "Error decoding uint16 [i:%d]", i) | ||||
| 		require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) | ||||
| 		require.Equal(t, i+2, offset, "Incorrect offset value [i:%d]", i) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkDecodeUint16(b *testing.B) { | ||||
| 	in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97} | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		decodeUint16(in, 0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecodeByteBool(t *testing.T) { | ||||
| 	expect := []struct { | ||||
| 		rawBytes   []byte | ||||
| 		result     bool | ||||
| 		offset     int | ||||
| 		shouldFail bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			rawBytes: []byte{0x00, 0x00}, | ||||
| 			result:   false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes: []byte{0x01, 0x00}, | ||||
| 			result:   true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			rawBytes:   []byte{0x01, 0x00}, | ||||
| 			offset:     5, | ||||
| 			shouldFail: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range expect { | ||||
| 		result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset) | ||||
| 		if wanted.shouldFail { | ||||
| 			require.Error(t, err, "Expected error decoding byte bool [i:%d]", i) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		require.NoError(t, err, "Error decoding byte bool [i:%d]", i) | ||||
| 		require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) | ||||
| 		require.Equal(t, 1, offset, "Incorrect offset value [i:%d]", i) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkDecodeByteBool(b *testing.B) { | ||||
| 	in := []byte{0x00, 0x00} | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		decodeByteBool(in, 0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEncodeBool(t *testing.T) { | ||||
| 	result := encodeBool(true) | ||||
| 	require.Equal(t, byte(1), result, "Incorrect encoded value; not true") | ||||
|  | ||||
| 	result = encodeBool(false) | ||||
| 	require.Equal(t, byte(0), result, "Incorrect encoded value; not false") | ||||
|  | ||||
| 	// Check failure. | ||||
| 	result = encodeBool(false) | ||||
| 	require.NotEqual(t, byte(1), result, "Expected failure, incorrect encoded value") | ||||
| } | ||||
|  | ||||
| func BenchmarkEncodeBool(b *testing.B) { | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		encodeBool(true) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEncodeBytes(t *testing.T) { | ||||
| 	result := encodeBytes([]byte("testing")) | ||||
| 	require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result, "Incorrect encoded value") | ||||
|  | ||||
| 	result = encodeBytes([]byte("testing")) | ||||
| 	require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result, "Expected failure, incorrect encoded value") | ||||
| } | ||||
|  | ||||
| func BenchmarkEncodeBytes(b *testing.B) { | ||||
| 	bb := []byte("testing") | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		encodeBytes(bb) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEncodeUint16(t *testing.T) { | ||||
| 	result := encodeUint16(0) | ||||
| 	require.Equal(t, []byte{0x00, 0x00}, result, "Incorrect encoded value, 0") | ||||
|  | ||||
| 	result = encodeUint16(32767) | ||||
| 	require.Equal(t, []byte{0x7f, 0xff}, result, "Incorrect encoded value, 32767") | ||||
|  | ||||
| 	result = encodeUint16(65535) | ||||
| 	require.Equal(t, []byte{0xff, 0xff}, result, "Incorrect encoded value, 65535") | ||||
| } | ||||
|  | ||||
| func BenchmarkEncodeUint16(b *testing.B) { | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		encodeUint16(32767) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEncodeString(t *testing.T) { | ||||
| 	result := encodeString("testing") | ||||
| 	require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result, "Incorrect encoded value, testing") | ||||
|  | ||||
| 	result = encodeString("") | ||||
| 	require.Equal(t, []uint8{0x00, 0x00}, result, "Incorrect encoded value, null") | ||||
|  | ||||
| 	result = encodeString("a") | ||||
| 	require.Equal(t, []uint8{0x00, 0x01, 0x61}, result, "Incorrect encoded value, a") | ||||
|  | ||||
| 	result = encodeString("b") | ||||
| 	require.NotEqual(t, []uint8{0x00, 0x00}, result, "Expected failure, incorrect encoded value, b") | ||||
|  | ||||
| } | ||||
|  | ||||
| func BenchmarkEncodeString(b *testing.B) { | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		encodeString("benchmarking") | ||||
| 	} | ||||
| } | ||||
| @@ -1,59 +0,0 @@ | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| ) | ||||
|  | ||||
| // FixedHeader contains the values of the fixed header portion of the MQTT packet. | ||||
| type FixedHeader struct { | ||||
| 	Type      byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1). | ||||
| 	Dup       bool // indicates if the packet was already sent at an earlier time. | ||||
| 	Qos       byte // indicates the quality of service expected. | ||||
| 	Retain    bool // whether the message should be retained. | ||||
| 	Remaining int  // the number of remaining bytes in the payload. | ||||
| } | ||||
|  | ||||
| // Encode encodes the FixedHeader and returns a bytes buffer. | ||||
| func (fh *FixedHeader) Encode(buf *bytes.Buffer) { | ||||
| 	buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain)) | ||||
| 	encodeLength(buf, fh.Remaining) | ||||
| } | ||||
|  | ||||
| // decode extracts the specification bits from the header byte. | ||||
| func (fh *FixedHeader) Decode(headerByte byte) error { | ||||
| 	fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes. | ||||
|  | ||||
| 	switch fh.Type { | ||||
| 	case Publish: | ||||
| 		fh.Dup = (headerByte>>3)&0x01 > 0 // Extract flags. Check if message is duplicate. | ||||
| 		fh.Qos = (headerByte >> 1) & 0x03 // Extract QoS flag. | ||||
| 		fh.Retain = headerByte&0x01 > 0   // Extract retain flag. | ||||
| 	case Pubrel: | ||||
| 		fh.Qos = (headerByte >> 1) & 0x03 | ||||
| 	case Subscribe: | ||||
| 		fh.Qos = (headerByte >> 1) & 0x03 | ||||
| 	case Unsubscribe: | ||||
| 		fh.Qos = (headerByte >> 1) & 0x03 | ||||
| 	default: | ||||
| 		if (headerByte>>3)&0x01 > 0 || (headerByte>>1)&0x03 > 0 || headerByte&0x01 > 0 { | ||||
| 			return ErrInvalidFlags | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // encodeLength writes length bits for the header. | ||||
| func encodeLength(buf *bytes.Buffer, length int) { | ||||
| 	for { | ||||
| 		digit := byte(length % 128) | ||||
| 		length /= 128 | ||||
| 		if length > 0 { | ||||
| 			digit |= 0x80 | ||||
| 		} | ||||
| 		buf.WriteByte(digit) | ||||
| 		if length == 0 { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -1,220 +0,0 @@ | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"math" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| type fixedHeaderTable struct { | ||||
| 	rawBytes    []byte | ||||
| 	header      FixedHeader | ||||
| 	packetError bool | ||||
| 	flagError   bool | ||||
| } | ||||
|  | ||||
| var fixedHeaderExpected = []fixedHeaderTable{ | ||||
| 	{ | ||||
| 		rawBytes: []byte{Connect << 4, 0x00}, | ||||
| 		header:   FixedHeader{Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Connack << 4, 0x00}, | ||||
| 		header:   FixedHeader{Connack, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish << 4, 0x00}, | ||||
| 		header:   FixedHeader{Publish, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Publish, false, 1, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Publish, false, 1, true, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish<<4 | 2<<1, 0x00}, | ||||
| 		header:   FixedHeader{Publish, false, 2, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Publish, false, 2, true, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<3, 0x00}, | ||||
| 		header:   FixedHeader{Publish, true, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Publish, true, 0, true, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Publish, true, 1, true, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00}, | ||||
| 		header:   FixedHeader{Publish, true, 2, true, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Puback << 4, 0x00}, | ||||
| 		header:   FixedHeader{Puback, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Pubrec << 4, 0x00}, | ||||
| 		header:   FixedHeader{Pubrec, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Pubrel, false, 1, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Pubcomp << 4, 0x00}, | ||||
| 		header:   FixedHeader{Pubcomp, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Subscribe, false, 1, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Suback << 4, 0x00}, | ||||
| 		header:   FixedHeader{Suback, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00}, | ||||
| 		header:   FixedHeader{Unsubscribe, false, 1, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Unsuback << 4, 0x00}, | ||||
| 		header:   FixedHeader{Unsuback, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Pingreq << 4, 0x00}, | ||||
| 		header:   FixedHeader{Pingreq, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Pingresp << 4, 0x00}, | ||||
| 		header:   FixedHeader{Pingresp, false, 0, false, 0}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Disconnect << 4, 0x00}, | ||||
| 		header:   FixedHeader{Disconnect, false, 0, false, 0}, | ||||
| 	}, | ||||
|  | ||||
| 	// remaining length | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish << 4, 0x0a}, | ||||
| 		header:   FixedHeader{Publish, false, 0, false, 10}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish << 4, 0x80, 0x04}, | ||||
| 		header:   FixedHeader{Publish, false, 0, false, 512}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish << 4, 0xd2, 0x07}, | ||||
| 		header:   FixedHeader{Publish, false, 0, false, 978}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01}, | ||||
| 		header:   FixedHeader{Publish, false, 0, false, 20102}, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes:    []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, | ||||
| 		header:      FixedHeader{Publish, false, 0, false, 333333333}, | ||||
| 		packetError: true, | ||||
| 	}, | ||||
|  | ||||
| 	// Invalid flags for packet | ||||
| 	{ | ||||
| 		rawBytes:  []byte{Connect<<4 | 1<<3, 0x00}, | ||||
| 		header:    FixedHeader{Connect, true, 0, false, 0}, | ||||
| 		flagError: true, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes:  []byte{Connect<<4 | 1<<1, 0x00}, | ||||
| 		header:    FixedHeader{Connect, false, 1, false, 0}, | ||||
| 		flagError: true, | ||||
| 	}, | ||||
| 	{ | ||||
| 		rawBytes:  []byte{Connect<<4 | 1, 0x00}, | ||||
| 		header:    FixedHeader{Connect, false, 0, true, 0}, | ||||
| 		flagError: true, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| func TestFixedHeaderEncode(t *testing.T) { | ||||
| 	for i, wanted := range fixedHeaderExpected { | ||||
| 		buf := new(bytes.Buffer) | ||||
| 		wanted.header.Encode(buf) | ||||
| 		if wanted.flagError == false { | ||||
| 			require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes) | ||||
| 			require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkFixedHeaderEncode(b *testing.B) { | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		fixedHeaderExpected[0].header.Encode(buf) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestFixedHeaderDecode(t *testing.T) { | ||||
| 	for i, wanted := range fixedHeaderExpected { | ||||
| 		fh := new(FixedHeader) | ||||
| 		err := fh.Decode(wanted.rawBytes[0]) | ||||
| 		if wanted.flagError { | ||||
| 			require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes) | ||||
| 		} else { | ||||
| 			require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes) | ||||
| 			require.Equal(t, wanted.header.Type, fh.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes) | ||||
| 			require.Equal(t, wanted.header.Dup, fh.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes) | ||||
| 			require.Equal(t, wanted.header.Qos, fh.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes) | ||||
| 			require.Equal(t, wanted.header.Retain, fh.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkFixedHeaderDecode(b *testing.B) { | ||||
| 	fh := new(FixedHeader) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		err := fh.Decode(fixedHeaderExpected[0].rawBytes[0]) | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEncodeLength(t *testing.T) { | ||||
| 	tt := []struct { | ||||
| 		have int | ||||
| 		want []byte | ||||
| 	}{ | ||||
| 		{ | ||||
| 			120, | ||||
| 			[]byte{0x78}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			math.MaxInt64, | ||||
| 			[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, wanted := range tt { | ||||
| 		buf := new(bytes.Buffer) | ||||
| 		encodeLength(buf, wanted.have) | ||||
| 		require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %s", i, wanted.have) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkEncodeLength(b *testing.B) { | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		encodeLength(buf, 120) | ||||
| 	} | ||||
| } | ||||
| @@ -1,673 +0,0 @@ | ||||
| package packets | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| ) | ||||
|  | ||||
| // All of the valid packet types and their packet identifier. | ||||
| const ( | ||||
| 	Reserved    byte = iota | ||||
| 	Connect          // 1 | ||||
| 	Connack          // 2 | ||||
| 	Publish          // 3 | ||||
| 	Puback           // 4 | ||||
| 	Pubrec           // 5 | ||||
| 	Pubrel           // 6 | ||||
| 	Pubcomp          // 7 | ||||
| 	Subscribe        // 8 | ||||
| 	Suback           // 9 | ||||
| 	Unsubscribe      // 10 | ||||
| 	Unsuback         // 11 | ||||
| 	Pingreq          // 12 | ||||
| 	Pingresp         // 13 | ||||
| 	Disconnect       // 14 | ||||
|  | ||||
| 	Accepted                      byte = 0x00 | ||||
| 	Failed                        byte = 0xFF | ||||
| 	CodeConnectBadProtocolVersion byte = 0x01 | ||||
| 	CodeConnectBadClientID        byte = 0x02 | ||||
| 	CodeConnectServerUnavailable  byte = 0x03 | ||||
| 	CodeConnectBadAuthValues      byte = 0x04 | ||||
| 	CodeConnectNotAuthorised      byte = 0x05 | ||||
| 	CodeConnectNetworkError       byte = 0xFE | ||||
| 	CodeConnectProtocolViolation  byte = 0xFF | ||||
| 	ErrSubAckNetworkError         byte = 0x80 | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// CONNECT | ||||
| 	ErrMalformedProtocolName    = errors.New("malformed packet: protocol name") | ||||
| 	ErrMalformedProtocolVersion = errors.New("malformed packet: protocol version") | ||||
| 	ErrMalformedFlags           = errors.New("malformed packet: flags") | ||||
| 	ErrMalformedKeepalive       = errors.New("malformed packet: keepalive") | ||||
| 	ErrMalformedClientID        = errors.New("malformed packet: client id") | ||||
| 	ErrMalformedWillTopic       = errors.New("malformed packet: will topic") | ||||
| 	ErrMalformedWillMessage     = errors.New("malformed packet: will message") | ||||
| 	ErrMalformedUsername        = errors.New("malformed packet: username") | ||||
| 	ErrMalformedPassword        = errors.New("malformed packet: password") | ||||
|  | ||||
| 	// CONNACK | ||||
| 	ErrMalformedSessionPresent = errors.New("malformed packet: session present") | ||||
| 	ErrMalformedReturnCode     = errors.New("malformed packet: return code") | ||||
|  | ||||
| 	// PUBLISH | ||||
| 	ErrMalformedTopic    = errors.New("malformed packet: topic name") | ||||
| 	ErrMalformedPacketID = errors.New("malformed packet: packet id") | ||||
|  | ||||
| 	// SUBSCRIBE | ||||
| 	ErrMalformedQoS = errors.New("malformed packet: qos") | ||||
|  | ||||
| 	// PACKETS | ||||
| 	ErrProtocolViolation        = errors.New("protocol violation") | ||||
| 	ErrOffsetStrOutOfRange      = errors.New("offset string out of range") | ||||
| 	ErrOffsetBytesOutOfRange    = errors.New("offset bytes out of range") | ||||
| 	ErrOffsetByteOutOfRange     = errors.New("offset byte out of range") | ||||
| 	ErrOffsetBoolOutOfRange     = errors.New("offset bool out of range") | ||||
| 	ErrOffsetUintOutOfRange     = errors.New("offset uint out of range") | ||||
| 	ErrOffsetStrInvalidUTF8     = errors.New("offset string invalid utf8") | ||||
| 	ErrInvalidFlags             = errors.New("invalid flags set for packet") | ||||
| 	ErrOversizedLengthIndicator = errors.New("protocol violation: oversized length indicator") | ||||
| 	ErrMissingPacketID          = errors.New("missing packet id") | ||||
| 	ErrSurplusPacketID          = errors.New("surplus packet id") | ||||
| ) | ||||
|  | ||||
| // Packet is an MQTT packet. Instead of providing a packet interface and variant | ||||
| // packet structs, this is a single concrete packet type to cover all packet | ||||
| // types, which allows us to take advantage of various compiler optimizations. | ||||
| type Packet struct { | ||||
| 	FixedHeader FixedHeader | ||||
|  | ||||
| 	PacketID uint16 | ||||
|  | ||||
| 	// Connect | ||||
| 	ProtocolName     []byte | ||||
| 	ProtocolVersion  byte | ||||
| 	CleanSession     bool | ||||
| 	WillFlag         bool | ||||
| 	WillQos          byte | ||||
| 	WillRetain       bool | ||||
| 	UsernameFlag     bool | ||||
| 	PasswordFlag     bool | ||||
| 	ReservedBit      byte | ||||
| 	Keepalive        uint16 | ||||
| 	ClientIdentifier string | ||||
| 	WillTopic        string | ||||
| 	WillMessage      []byte | ||||
| 	Username         []byte | ||||
| 	Password         []byte | ||||
|  | ||||
| 	// Connack | ||||
| 	SessionPresent bool | ||||
| 	ReturnCode     byte | ||||
|  | ||||
| 	// Publish | ||||
| 	TopicName string | ||||
| 	Payload   []byte | ||||
|  | ||||
| 	// Subscribe, Unsubscribe | ||||
| 	Topics []string | ||||
| 	Qoss   []byte | ||||
|  | ||||
| 	ReturnCodes []byte // Suback | ||||
| } | ||||
|  | ||||
| // ConnectEncode encodes a connect packet. | ||||
| func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	protoName := encodeBytes(pk.ProtocolName) | ||||
| 	protoVersion := pk.ProtocolVersion | ||||
| 	flag := encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7 | ||||
| 	keepalive := encodeUint16(pk.Keepalive) | ||||
| 	clientID := encodeString(pk.ClientIdentifier) | ||||
|  | ||||
| 	var willTopic, willFlag, usernameFlag, passwordFlag []byte | ||||
|  | ||||
| 	// If will flag is set, add topic and message. | ||||
| 	if pk.WillFlag { | ||||
| 		willTopic = encodeString(pk.WillTopic) | ||||
| 		willFlag = encodeBytes(pk.WillMessage) | ||||
| 	} | ||||
|  | ||||
| 	// If username flag is set, add username. | ||||
| 	if pk.UsernameFlag { | ||||
| 		usernameFlag = encodeBytes(pk.Username) | ||||
| 	} | ||||
|  | ||||
| 	// If password flag is set, add password. | ||||
| 	if pk.PasswordFlag { | ||||
| 		passwordFlag = encodeBytes(pk.Password) | ||||
| 	} | ||||
|  | ||||
| 	// Get a length for the connect header. This is not super pretty, but it works. | ||||
| 	pk.FixedHeader.Remaining = | ||||
| 		len(protoName) + 1 + 1 + len(keepalive) + len(clientID) + | ||||
| 			len(willTopic) + len(willFlag) + | ||||
| 			len(usernameFlag) + len(passwordFlag) | ||||
|  | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
|  | ||||
| 	// Eschew magic for readability. | ||||
| 	buf.Write(protoName) | ||||
| 	buf.WriteByte(protoVersion) | ||||
| 	buf.WriteByte(flag) | ||||
| 	buf.Write(keepalive) | ||||
| 	buf.Write(clientID) | ||||
| 	buf.Write(willTopic) | ||||
| 	buf.Write(willFlag) | ||||
| 	buf.Write(usernameFlag) | ||||
| 	buf.Write(passwordFlag) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ConnectDecode decodes a connect packet. | ||||
| func (pk *Packet) ConnectDecode(buf []byte) error { | ||||
| 	var offset int | ||||
| 	var err error | ||||
|  | ||||
| 	// Unpack protocol name and version. | ||||
| 	pk.ProtocolName, offset, err = decodeBytes(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedProtocolName | ||||
| 	} | ||||
|  | ||||
| 	pk.ProtocolVersion, offset, err = decodeByte(buf, offset) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedProtocolVersion | ||||
| 	} | ||||
| 	// Unpack flags byte. | ||||
| 	flags, offset, err := decodeByte(buf, offset) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedFlags | ||||
| 	} | ||||
| 	pk.ReservedBit = 1 & flags | ||||
| 	pk.CleanSession = 1&(flags>>1) > 0 | ||||
| 	pk.WillFlag = 1&(flags>>2) > 0 | ||||
| 	pk.WillQos = 3 & (flags >> 3) // this one is not a bool | ||||
| 	pk.WillRetain = 1&(flags>>5) > 0 | ||||
| 	pk.PasswordFlag = 1&(flags>>6) > 0 | ||||
| 	pk.UsernameFlag = 1&(flags>>7) > 0 | ||||
|  | ||||
| 	// Get keepalive interval. | ||||
| 	pk.Keepalive, offset, err = decodeUint16(buf, offset) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedKeepalive | ||||
| 	} | ||||
|  | ||||
| 	// Get client ID. | ||||
| 	pk.ClientIdentifier, offset, err = decodeString(buf, offset) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedClientID | ||||
| 	} | ||||
|  | ||||
| 	// Get Last Will and Testament topic and message if applicable. | ||||
| 	if pk.WillFlag { | ||||
| 		pk.WillTopic, offset, err = decodeString(buf, offset) | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedWillTopic | ||||
| 		} | ||||
|  | ||||
| 		pk.WillMessage, offset, err = decodeBytes(buf, offset) | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedWillMessage | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Get username and password if applicable. | ||||
| 	if pk.UsernameFlag { | ||||
| 		pk.Username, offset, err = decodeBytes(buf, offset) | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedUsername | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if pk.PasswordFlag { | ||||
| 		pk.Password, offset, err = decodeBytes(buf, offset) | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedPassword | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
|  | ||||
| } | ||||
|  | ||||
| // ConnectValidate ensures the connect packet is compliant. | ||||
| func (pk *Packet) ConnectValidate() (b byte, err error) { | ||||
|  | ||||
| 	// End if protocol name is bad. | ||||
| 	if bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) != 0 && | ||||
| 		bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) != 0 { | ||||
| 		return CodeConnectProtocolViolation, ErrProtocolViolation | ||||
| 	} | ||||
|  | ||||
| 	// End if protocol version is bad. | ||||
| 	if (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) == 0 && pk.ProtocolVersion != 3) || | ||||
| 		(bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) == 0 && pk.ProtocolVersion != 4) { | ||||
| 		return CodeConnectBadProtocolVersion, ErrProtocolViolation | ||||
| 	} | ||||
|  | ||||
| 	// End if reserved bit is not 0. | ||||
| 	if pk.ReservedBit != 0 { | ||||
| 		return CodeConnectProtocolViolation, ErrProtocolViolation | ||||
| 	} | ||||
|  | ||||
| 	// End if ClientID is too long. | ||||
| 	if len(pk.ClientIdentifier) > 65535 { | ||||
| 		return CodeConnectProtocolViolation, ErrProtocolViolation | ||||
| 	} | ||||
|  | ||||
| 	// End if password flag is set without a username. | ||||
| 	if pk.PasswordFlag && !pk.UsernameFlag { | ||||
| 		return CodeConnectProtocolViolation, ErrProtocolViolation | ||||
| 	} | ||||
|  | ||||
| 	// End if Username or Password is too long. | ||||
| 	if len(pk.Username) > 65535 || len(pk.Password) > 65535 { | ||||
| 		return CodeConnectProtocolViolation, ErrProtocolViolation | ||||
| 	} | ||||
|  | ||||
| 	// End if client id isn't set and clean session is false. | ||||
| 	if !pk.CleanSession && len(pk.ClientIdentifier) == 0 { | ||||
| 		return CodeConnectBadClientID, ErrProtocolViolation | ||||
| 	} | ||||
|  | ||||
| 	return Accepted, nil | ||||
| } | ||||
|  | ||||
| // ConnackEncode encodes a Connack packet. | ||||
| func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Remaining = 2 | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.WriteByte(encodeBool(pk.SessionPresent)) | ||||
| 	buf.WriteByte(pk.ReturnCode) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ConnackDecode decodes a Connack packet. | ||||
| func (pk *Packet) ConnackDecode(buf []byte) error { | ||||
| 	var offset int | ||||
| 	var err error | ||||
|  | ||||
| 	pk.SessionPresent, offset, err = decodeByteBool(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedSessionPresent | ||||
| 	} | ||||
|  | ||||
| 	pk.ReturnCode, offset, err = decodeByte(buf, offset) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedReturnCode | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // DisconnectEncode encodes a Disconnect packet. | ||||
| func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PingreqEncode encodes a Pingreq packet. | ||||
| func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PingrespEncode encodes a Pingresp packet. | ||||
| func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PubackEncode encodes a Puback packet. | ||||
| func (pk *Packet) PubackEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Remaining = 2 | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.Write(encodeUint16(pk.PacketID)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PubackDecode decodes a Puback packet. | ||||
| func (pk *Packet) PubackDecode(buf []byte) error { | ||||
| 	var err error | ||||
| 	pk.PacketID, _, err = decodeUint16(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedPacketID | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PubcompEncode encodes a Pubcomp packet. | ||||
| func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Remaining = 2 | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.Write(encodeUint16(pk.PacketID)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PubcompDecode decodes a Pubcomp packet. | ||||
| func (pk *Packet) PubcompDecode(buf []byte) error { | ||||
| 	var err error | ||||
| 	pk.PacketID, _, err = decodeUint16(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedPacketID | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PublishEncode encodes a Publish packet. | ||||
| func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { | ||||
| 	topicName := encodeString(pk.TopicName) | ||||
| 	var packetID []byte | ||||
|  | ||||
| 	// Add PacketID if QOS is set. | ||||
| 	// [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. | ||||
| 	if pk.FixedHeader.Qos > 0 { | ||||
|  | ||||
| 		// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. | ||||
| 		if pk.PacketID == 0 { | ||||
| 			return ErrMissingPacketID | ||||
| 		} | ||||
|  | ||||
| 		packetID = encodeUint16(pk.PacketID) | ||||
| 	} | ||||
|  | ||||
| 	pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload) | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.Write(topicName) | ||||
| 	buf.Write(packetID) | ||||
| 	buf.Write(pk.Payload) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PublishDecode extracts the data values from the packet. | ||||
| func (pk *Packet) PublishDecode(buf []byte) error { | ||||
| 	var offset int | ||||
| 	var err error | ||||
|  | ||||
| 	pk.TopicName, offset, err = decodeString(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedTopic | ||||
| 	} | ||||
|  | ||||
| 	// If QOS decode Packet ID. | ||||
| 	if pk.FixedHeader.Qos > 0 { | ||||
| 		pk.PacketID, offset, err = decodeUint16(buf, offset) | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedPacketID | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	pk.Payload = buf[offset:] | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PublishCopy creates a new instance of Publish packet bearing the | ||||
| // same payload and destination topic, but with an empty header for | ||||
| // inheriting new QoS flags, etc. | ||||
| func (pk *Packet) PublishCopy() Packet { | ||||
| 	return Packet{ | ||||
| 		FixedHeader: FixedHeader{ | ||||
| 			Type:   Publish, | ||||
| 			Retain: pk.FixedHeader.Retain, | ||||
| 		}, | ||||
| 		TopicName: pk.TopicName, | ||||
| 		Payload:   pk.Payload, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // PublishValidate validates a publish packet. | ||||
| func (pk *Packet) PublishValidate() (byte, error) { | ||||
|  | ||||
| 	// @SPEC [MQTT-2.3.1-1] | ||||
| 	// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. | ||||
| 	if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { | ||||
| 		return Failed, ErrMissingPacketID | ||||
| 	} | ||||
|  | ||||
| 	// @SPEC [MQTT-2.3.1-5] | ||||
| 	// A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. | ||||
| 	if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 { | ||||
| 		return Failed, ErrSurplusPacketID | ||||
| 	} | ||||
|  | ||||
| 	return Accepted, nil | ||||
| } | ||||
|  | ||||
| // PubrecEncode encodes a Pubrec packet. | ||||
| func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Remaining = 2 | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.Write(encodeUint16(pk.PacketID)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PubrecDecode decodes a Pubrec packet. | ||||
| func (pk *Packet) PubrecDecode(buf []byte) error { | ||||
| 	var err error | ||||
| 	pk.PacketID, _, err = decodeUint16(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedPacketID | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PubrelEncode encodes a Pubrel packet. | ||||
| func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Remaining = 2 | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.Write(encodeUint16(pk.PacketID)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // PubrelDecode decodes a Pubrel packet. | ||||
| func (pk *Packet) PubrelDecode(buf []byte) error { | ||||
| 	var err error | ||||
| 	pk.PacketID, _, err = decodeUint16(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedPacketID | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SubackEncode encodes a Suback packet. | ||||
| func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { | ||||
| 	packetID := encodeUint16(pk.PacketID) | ||||
| 	pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length. | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
|  | ||||
| 	buf.Write(packetID)       // Encode Packet ID. | ||||
| 	buf.Write(pk.ReturnCodes) // Encode granted QOS flags. | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SubackDecode decodes a Suback packet. | ||||
| func (pk *Packet) SubackDecode(buf []byte) error { | ||||
| 	var offset int | ||||
| 	var err error | ||||
|  | ||||
| 	// Get Packet ID. | ||||
| 	pk.PacketID, offset, err = decodeUint16(buf, offset) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedPacketID | ||||
| 	} | ||||
|  | ||||
| 	// Get Granted QOS flags. | ||||
| 	pk.ReturnCodes = buf[offset:] | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SubscribeEncode encodes a Subscribe packet. | ||||
| func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	// Add the Packet ID. | ||||
| 	// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. | ||||
| 	if pk.PacketID == 0 { | ||||
| 		return ErrMissingPacketID | ||||
| 	} | ||||
|  | ||||
| 	packetID := encodeUint16(pk.PacketID) | ||||
|  | ||||
| 	// Count topics lengths and associated QOS flags. | ||||
| 	var topicsLen int | ||||
| 	for _, topic := range pk.Topics { | ||||
| 		topicsLen += len(encodeString(topic)) + 1 | ||||
| 	} | ||||
|  | ||||
| 	pk.FixedHeader.Remaining = len(packetID) + topicsLen | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.Write(packetID) | ||||
|  | ||||
| 	// Add all provided topic names and associated QOS flags. | ||||
| 	for i, topic := range pk.Topics { | ||||
| 		buf.Write(encodeString(topic)) | ||||
| 		buf.WriteByte(pk.Qoss[i]) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SubscribeDecode decodes a Subscribe packet. | ||||
| func (pk *Packet) SubscribeDecode(buf []byte) error { | ||||
| 	var offset int | ||||
| 	var err error | ||||
|  | ||||
| 	// Get the Packet ID. | ||||
| 	pk.PacketID, offset, err = decodeUint16(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedPacketID | ||||
| 	} | ||||
|  | ||||
| 	// Keep decoding until there's no space left. | ||||
| 	for offset < len(buf) { | ||||
|  | ||||
| 		// Decode Topic Name. | ||||
| 		var topic string | ||||
| 		topic, offset, err = decodeString(buf, offset) | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedTopic | ||||
| 		} | ||||
| 		pk.Topics = append(pk.Topics, topic) | ||||
|  | ||||
| 		// Decode QOS flag. | ||||
| 		var qos byte | ||||
| 		qos, offset, err = decodeByte(buf, offset) | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedQoS | ||||
| 		} | ||||
|  | ||||
| 		// Ensure QoS byte is within range. | ||||
| 		if !(qos >= 0 && qos <= 2) { | ||||
| 			//if !validateQoS(qos) { | ||||
| 			return ErrMalformedQoS | ||||
| 		} | ||||
|  | ||||
| 		pk.Qoss = append(pk.Qoss, qos) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SubscribeValidate ensures the packet is compliant. | ||||
| func (pk *Packet) SubscribeValidate() (byte, error) { | ||||
| 	// @SPEC [MQTT-2.3.1-1]. | ||||
| 	// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. | ||||
| 	if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { | ||||
| 		return Failed, ErrMissingPacketID | ||||
| 	} | ||||
|  | ||||
| 	return Accepted, nil | ||||
| } | ||||
|  | ||||
| // UnsubackEncode encodes an Unsuback packet. | ||||
| func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { | ||||
| 	pk.FixedHeader.Remaining = 2 | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.Write(encodeUint16(pk.PacketID)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // UnsubackDecode decodes an Unsuback packet. | ||||
| func (pk *Packet) UnsubackDecode(buf []byte) error { | ||||
| 	var err error | ||||
| 	pk.PacketID, _, err = decodeUint16(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedPacketID | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // UnsubscribeEncode encodes an Unsubscribe packet. | ||||
| func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { | ||||
|  | ||||
| 	// Add the Packet ID. | ||||
| 	// [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. | ||||
| 	if pk.PacketID == 0 { | ||||
| 		return ErrMissingPacketID | ||||
| 	} | ||||
|  | ||||
| 	packetID := encodeUint16(pk.PacketID) | ||||
|  | ||||
| 	// Count topics lengths. | ||||
| 	var topicsLen int | ||||
| 	for _, topic := range pk.Topics { | ||||
| 		topicsLen += len(encodeString(topic)) | ||||
| 	} | ||||
|  | ||||
| 	pk.FixedHeader.Remaining = len(packetID) + topicsLen | ||||
| 	pk.FixedHeader.Encode(buf) | ||||
| 	buf.Write(packetID) | ||||
|  | ||||
| 	// Add all provided topic names. | ||||
| 	for _, topic := range pk.Topics { | ||||
| 		buf.Write(encodeString(topic)) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // UnsubscribeDecode decodes an Unsubscribe packet. | ||||
| func (pk *Packet) UnsubscribeDecode(buf []byte) error { | ||||
| 	var offset int | ||||
| 	var err error | ||||
|  | ||||
| 	// Get the Packet ID. | ||||
| 	pk.PacketID, offset, err = decodeUint16(buf, 0) | ||||
| 	if err != nil { | ||||
| 		return ErrMalformedPacketID | ||||
| 	} | ||||
|  | ||||
| 	// Keep decoding until there's no space left. | ||||
| 	for offset < len(buf) { | ||||
| 		var t string | ||||
| 		t, offset, err = decodeString(buf, offset) // Decode Topic Name. | ||||
| 		if err != nil { | ||||
| 			return ErrMalformedTopic | ||||
| 		} | ||||
|  | ||||
| 		if len(t) > 0 { | ||||
| 			pk.Topics = append(pk.Topics, t) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
|  | ||||
| } | ||||
|  | ||||
| // UnsubscribeValidate validates an Unsubscribe packet. | ||||
| func (pk *Packet) UnsubscribeValidate() (byte, error) { | ||||
| 	// @SPEC [MQTT-2.3.1-1]. | ||||
| 	// SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. | ||||
| 	if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { | ||||
| 		return Failed, ErrMissingPacketID | ||||
| 	} | ||||
|  | ||||
| 	return Accepted, nil | ||||
| } | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -1,343 +0,0 @@ | ||||
| package topics | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
| 	"sync" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/server/internal/packets" | ||||
| ) | ||||
|  | ||||
| // Subscriptions is a map of subscriptions keyed on client. | ||||
| type Subscriptions map[string]byte | ||||
|  | ||||
| // Index is a prefix/trie tree containing topic subscribers and retained messages. | ||||
| type Index struct { | ||||
| 	mu   sync.RWMutex // a mutex for locking the whole index. | ||||
| 	Root *Leaf        // a leaf containing a message and more leaves. | ||||
| } | ||||
|  | ||||
| // New returns a pointer to a new instance of Index. | ||||
| func New() *Index { | ||||
| 	return &Index{ | ||||
| 		Root: &Leaf{ | ||||
| 			Leaves:  make(map[string]*Leaf), | ||||
| 			Clients: make(map[string]byte), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // RetainMessage saves a message payload to the end of a topic branch. Returns | ||||
| // 1 if a retained message was added, 0 if there was no change, and -1 if the | ||||
| // retained message was removed. | ||||
| func (x *Index) RetainMessage(msg packets.Packet) int64 { | ||||
| 	var q int64 | ||||
|  | ||||
| 	x.mu.Lock() | ||||
| 	defer x.mu.Unlock() | ||||
| 	n := x.poperate(msg.TopicName) | ||||
| 	if len(msg.Payload) > 0 { | ||||
| 		if n.Message.FixedHeader.Retain == false { | ||||
| 			q = 1 | ||||
| 		} | ||||
| 		n.Message = msg | ||||
| 	} else { | ||||
| 		if n.Message.FixedHeader.Retain == true { | ||||
| 			q = -1 | ||||
| 		} | ||||
| 		x.unpoperate(msg.TopicName, "", true) | ||||
| 	} | ||||
|  | ||||
| 	return q | ||||
| } | ||||
|  | ||||
| // Subscribe creates a subscription filter for a client. Returns true if the | ||||
| // subscription was new. | ||||
| func (x *Index) Subscribe(filter, client string, qos byte) bool { | ||||
| 	x.mu.Lock() | ||||
| 	defer x.mu.Unlock() | ||||
|  | ||||
| 	n := x.poperate(filter) | ||||
| 	_, ok := n.Clients[client] | ||||
| 	n.Clients[client] = qos | ||||
| 	n.Filter = filter | ||||
|  | ||||
| 	return !ok | ||||
| } | ||||
|  | ||||
| // Unsubscribe removes a subscription filter for a client. Returns true if an | ||||
| // unsubscribe action successful and the subscription existed. | ||||
| func (x *Index) Unsubscribe(filter, client string) bool { | ||||
| 	x.mu.Lock() | ||||
| 	defer x.mu.Unlock() | ||||
|  | ||||
| 	n := x.poperate(filter) | ||||
| 	_, ok := n.Clients[client] | ||||
|  | ||||
| 	return x.unpoperate(filter, client, false) && ok | ||||
| } | ||||
|  | ||||
| // unpoperate steps backward through a trie sequence and removes any orphaned | ||||
| // nodes. If a client id is specified, it will unsubscribe a client. If message | ||||
| // is true, it will delete a retained message. | ||||
| func (x *Index) unpoperate(filter string, client string, message bool) bool { | ||||
| 	var d int // Walk to end leaf. | ||||
| 	var particle string | ||||
| 	var hasNext = true | ||||
| 	e := x.Root | ||||
| 	for hasNext { | ||||
| 		particle, hasNext = isolateParticle(filter, d) | ||||
| 		d++ | ||||
| 		e, _ = e.Leaves[particle] | ||||
|  | ||||
| 		// If the topic part doesn't exist in the tree, there's nothing | ||||
| 		// left to do. | ||||
| 		if e == nil { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Step backward removing client and orphaned leaves. | ||||
| 	var key string | ||||
| 	var orphaned bool | ||||
| 	var end = true | ||||
| 	for e.Parent != nil { | ||||
| 		key = e.Key | ||||
|  | ||||
| 		// Wipe the client from this leaf if it's the filter end. | ||||
| 		if end { | ||||
| 			if client != "" { | ||||
| 				delete(e.Clients, client) | ||||
| 			} | ||||
| 			if message { | ||||
| 				e.Message = packets.Packet{} | ||||
| 			} | ||||
| 			end = false | ||||
| 		} | ||||
|  | ||||
| 		// If this leaf is empty, note it as orphaned. | ||||
| 		orphaned = len(e.Clients) == 0 && len(e.Leaves) == 0 && !e.Message.FixedHeader.Retain | ||||
|  | ||||
| 		// Traverse up the branch. | ||||
| 		e = e.Parent | ||||
|  | ||||
| 		// If the leaf we just came from was empty, delete it. | ||||
| 		if orphaned { | ||||
| 			delete(e.Leaves, key) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
|  | ||||
| } | ||||
|  | ||||
| // poperate iterates and populates through a topic/filter path, instantiating | ||||
| // leaves as it goes and returning the final leaf in the branch. | ||||
| // poperate is a more enjoyable word than iterpop. | ||||
| func (x *Index) poperate(topic string) *Leaf { | ||||
| 	var d int | ||||
| 	var particle string | ||||
| 	var hasNext = true | ||||
| 	n := x.Root | ||||
| 	for hasNext { | ||||
| 		particle, hasNext = isolateParticle(topic, d) | ||||
| 		d++ | ||||
|  | ||||
| 		child, _ := n.Leaves[particle] | ||||
| 		if child == nil { | ||||
| 			child = &Leaf{ | ||||
| 				Key:     particle, | ||||
| 				Parent:  n, | ||||
| 				Leaves:  make(map[string]*Leaf), | ||||
| 				Clients: make(map[string]byte), | ||||
| 			} | ||||
| 			n.Leaves[particle] = child | ||||
| 		} | ||||
| 		n = child | ||||
| 	} | ||||
|  | ||||
| 	return n | ||||
| } | ||||
|  | ||||
| // Subscribers returns a map of clients who are subscribed to matching filters. | ||||
| func (x *Index) Subscribers(topic string) Subscriptions { | ||||
| 	x.mu.RLock() | ||||
| 	defer x.mu.RUnlock() | ||||
| 	return x.Root.scanSubscribers(topic, 0, make(Subscriptions)) | ||||
| } | ||||
|  | ||||
| // Messages returns a slice of retained topic messages which match a filter. | ||||
| func (x *Index) Messages(filter string) []packets.Packet { | ||||
| 	// ReLeaf("messages", x.Root, 0) | ||||
| 	x.mu.RLock() | ||||
| 	defer x.mu.RUnlock() | ||||
| 	return x.Root.scanMessages(filter, 0, make([]packets.Packet, 0, 32)) | ||||
| } | ||||
|  | ||||
| // Leaf is a child node on the tree. | ||||
| type Leaf struct { | ||||
| 	Key     string           // the key that was used to create the leaf. | ||||
| 	Parent  *Leaf            // a pointer to the parent node for the leaf. | ||||
| 	Leaves  map[string]*Leaf // a map of child nodes, keyed on particle id. | ||||
| 	Clients map[string]byte  // a map of client ids subscribed to the topic. | ||||
| 	Filter  string           // the path of the topic filter being matched. | ||||
| 	Message packets.Packet   // a message which has been retained for a specific topic. | ||||
| } | ||||
|  | ||||
| // scanSubscribers recursively steps through a branch of leaves finding clients who | ||||
| // have subscription filters matching a topic, and their highest QoS byte. | ||||
| func (l *Leaf) scanSubscribers(topic string, d int, clients Subscriptions) Subscriptions { | ||||
| 	part, hasNext := isolateParticle(topic, d) | ||||
|  | ||||
| 	// For either the topic part, a +, or a #, follow the branch. | ||||
| 	for _, particle := range []string{part, "+", "#"} { | ||||
|  | ||||
| 		// Topics beginning with the reserved $ character are restricted from | ||||
| 		// being returned for top level wildcards. | ||||
| 		if d == 0 && len(part) > 0 && part[0] == '$' && (particle == "+" || particle == "#") { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if child, ok := l.Leaves[particle]; ok { | ||||
|  | ||||
| 			// We're only interested in getting clients from the final | ||||
| 			// element in the topic, or those with wildhashes. | ||||
| 			if !hasNext || particle == "#" { | ||||
|  | ||||
| 				// Capture the highest QOS byte for any client with a filter | ||||
| 				// matching the topic. | ||||
| 				for client, qos := range child.Clients { | ||||
| 					if ex, ok := clients[client]; !ok || ex < qos { | ||||
| 						clients[client] = qos | ||||
| 					} | ||||
| 				} | ||||
|  | ||||
| 				// Make sure we also capture any client who are listening | ||||
| 				// to this topic via path/# | ||||
| 				if !hasNext { | ||||
| 					if extra, ok := child.Leaves["#"]; ok { | ||||
| 						for client, qos := range extra.Clients { | ||||
| 							if ex, ok := clients[client]; !ok || ex < qos { | ||||
| 								clients[client] = qos | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			// If this branch has hit a wildhash, just return immediately. | ||||
| 			if particle == "#" { | ||||
| 				return clients | ||||
| 			} else if hasNext { | ||||
| 				clients = child.scanSubscribers(topic, d+1, clients) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return clients | ||||
| } | ||||
|  | ||||
| // scanMessages recursively steps through a branch of leaves finding retained messages | ||||
| // that match a topic filter. Setting `d` to -1 will enable wildhash mode, and will | ||||
| // recursively check ALL child leaves in every subsequent branch. | ||||
| func (l *Leaf) scanMessages(filter string, d int, messages []packets.Packet) []packets.Packet { | ||||
|  | ||||
| 	// If a wildhash mode has been set, continue recursively checking through all | ||||
| 	// child leaves regardless of their particle key. | ||||
| 	if d == -1 { | ||||
| 		for _, child := range l.Leaves { | ||||
| 			if child.Message.FixedHeader.Retain { | ||||
| 				messages = append(messages, child.Message) | ||||
| 			} | ||||
| 			messages = child.scanMessages(filter, -1, messages) | ||||
| 		} | ||||
| 		return messages | ||||
| 	} | ||||
|  | ||||
| 	// Otherwise, we'll get the particle for d in the filter. | ||||
| 	particle, hasNext := isolateParticle(filter, d) | ||||
|  | ||||
| 	// If there's no more particles after this one, then take the messages from | ||||
| 	// these topics. | ||||
| 	if !hasNext { | ||||
|  | ||||
| 		// Wildcards and Wildhashes must be checked first, otherwise they | ||||
| 		// may be detected as standard particles, and not act properly. | ||||
| 		if particle == "+" || particle == "#" { | ||||
|  | ||||
| 			// Otherwise, if it's a wildcard or wildhash, get messages from all | ||||
| 			// the child leaves. This wildhash captures messages on the actual | ||||
| 			// wildhash position, whereas the d == -1 block collects subsequent | ||||
| 			// messages further down the branch. | ||||
| 			for _, child := range l.Leaves { | ||||
| 				if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' { | ||||
| 					continue | ||||
| 				} | ||||
| 				if child.Message.FixedHeader.Retain { | ||||
| 					messages = append(messages, child.Message) | ||||
| 				} | ||||
| 			} | ||||
| 		} else if child, ok := l.Leaves[particle]; ok { | ||||
| 			if child.Message.FixedHeader.Retain { | ||||
| 				messages = append(messages, child.Message) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 	} else { | ||||
|  | ||||
| 		// If it's not the last particle, branch out to the next leaves, scanning | ||||
| 		// all available if it's a wildcard, or just one if it's a specific particle. | ||||
| 		if particle == "+" { | ||||
| 			for _, child := range l.Leaves { | ||||
| 				if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' { | ||||
| 					continue | ||||
| 				} | ||||
| 				messages = child.scanMessages(filter, d+1, messages) | ||||
| 			} | ||||
| 		} else if child, ok := l.Leaves[particle]; ok { | ||||
| 			messages = child.scanMessages(filter, d+1, messages) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// If the particle was a wildhash, scan all the child leaves setting the | ||||
| 	// d value to wildhash mode. | ||||
| 	if particle == "#" { | ||||
| 		for _, child := range l.Leaves { | ||||
| 			if d == 0 && len(child.Key) > 0 && child.Key[0] == '$' { | ||||
| 				continue | ||||
| 			} | ||||
| 			messages = child.scanMessages(filter, -1, messages) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return messages | ||||
| } | ||||
|  | ||||
| // isolateParticle extracts a particle between d / and d+1 / without allocations. | ||||
| func isolateParticle(filter string, d int) (particle string, hasNext bool) { | ||||
| 	var next, end int | ||||
| 	for i := 0; end > -1 && i <= d; i++ { | ||||
| 		end = strings.IndexRune(filter, '/') | ||||
| 		if d > -1 && i == d && end > -1 { | ||||
| 			hasNext = true | ||||
| 			particle = filter[next:end] | ||||
| 		} else if end > -1 { | ||||
| 			hasNext = false | ||||
| 			filter = filter[end+1:] | ||||
| 		} else { | ||||
| 			hasNext = false | ||||
| 			particle = filter[next:] | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // ReLeaf is a dev function for showing the trie leafs. | ||||
| /* | ||||
| func ReLeaf(m string, leaf *Leaf, d int) { | ||||
| 	for k, v := range leaf.Leaves { | ||||
| 		fmt.Println(m, d, strings.Repeat("  ", d), k) | ||||
| 		ReLeaf(m, v, d+1) | ||||
| 	} | ||||
| } | ||||
| */ | ||||
| @@ -1,484 +0,0 @@ | ||||
| package topics | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/server/internal/packets" | ||||
| ) | ||||
|  | ||||
| func TestNew(t *testing.T) { | ||||
| 	index := New() | ||||
| 	require.NotNil(t, index) | ||||
| 	require.NotNil(t, index.Root) | ||||
| } | ||||
|  | ||||
| func BenchmarkNew(b *testing.B) { | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		New() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPoperate(t *testing.T) { | ||||
| 	index := New() | ||||
| 	child := index.poperate("path/to/my/mqtt") | ||||
| 	require.Equal(t, "mqtt", child.Key) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"]) | ||||
|  | ||||
| 	child = index.poperate("a/b/c/d/e") | ||||
| 	require.Equal(t, "e", child.Key) | ||||
| 	child = index.poperate("a/b/c/c/a") | ||||
| 	require.Equal(t, "a", child.Key) | ||||
| } | ||||
|  | ||||
| func BenchmarkPoperate(b *testing.B) { | ||||
| 	index := New() | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		index.poperate("path/to/my/mqtt") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestUnpoperate(t *testing.T) { | ||||
| 	index := New() | ||||
| 	index.Subscribe("path/to/my/mqtt", "client-1", 0) | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1") | ||||
|  | ||||
| 	index.Subscribe("path/to/another/mqtt", "client-1", 0) | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1") | ||||
|  | ||||
| 	pk := packets.Packet{TopicName: "path/to/retained/message", Payload: []byte{'h', 'e', 'l', 'l', 'o'}} | ||||
| 	index.RetainMessage(pk) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"]) | ||||
| 	require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["retained"].Leaves["message"].Message) | ||||
|  | ||||
| 	pk2 := packets.Packet{TopicName: "path/to/my/mqtt", Payload: []byte{'s', 'h', 'a', 'r', 'e', 'd'}} | ||||
| 	index.RetainMessage(pk2) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"]) | ||||
| 	require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message) | ||||
|  | ||||
| 	index.unpoperate("path/to/my/mqtt", "", true) // delete retained | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1") | ||||
| 	require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message.FixedHeader.Retain) | ||||
|  | ||||
| 	index.unpoperate("path/to/my/mqtt", "client-1", false) // unsubscribe client | ||||
| 	require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"]) | ||||
|  | ||||
| 	index.unpoperate("path/to/retained/message", "", true) // delete retained | ||||
| 	require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves, "my") | ||||
|  | ||||
| 	index.unpoperate("path/to/whatever", "client-1", false) // unsubscribe client | ||||
| 	require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"]) | ||||
|  | ||||
| 	//require.Empty(t, index.Root.Leaves["path"]) | ||||
|  | ||||
| } | ||||
|  | ||||
| func BenchmarkUnpoperate(b *testing.B) { | ||||
| 	index := New() | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		index.poperate("path/to/my/mqtt") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRetainMessage(t *testing.T) { | ||||
| 	pk := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Retain: true, | ||||
| 		}, | ||||
| 		TopicName: "path/to/my/mqtt", | ||||
| 		Payload:   []byte{'h', 'e', 'l', 'l', 'o'}, | ||||
| 	} | ||||
| 	pk2 := packets.Packet{ | ||||
| 		FixedHeader: packets.FixedHeader{ | ||||
| 			Retain: true, | ||||
| 		}, | ||||
| 		TopicName: "path/to/another/mqtt", | ||||
| 		Payload:   []byte{'h', 'e', 'l', 'l', 'o'}, | ||||
| 	} | ||||
|  | ||||
| 	index := New() | ||||
| 	q := index.RetainMessage(pk) | ||||
| 	require.Equal(t, int64(1), q) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"]) | ||||
| 	require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message) | ||||
|  | ||||
| 	index.Subscribe("path/to/another/mqtt", "client-1", 0) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients["client-1"]) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"]) | ||||
|  | ||||
| 	q = index.RetainMessage(pk2) | ||||
| 	require.Equal(t, int64(1), q) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"]) | ||||
| 	require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message) | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1") | ||||
|  | ||||
| 	q = index.RetainMessage(pk2) // already exsiting | ||||
| 	require.Equal(t, int64(0), q) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"]) | ||||
| 	require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message) | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1") | ||||
|  | ||||
| 	// Delete retained | ||||
| 	pk3 := packets.Packet{TopicName: "path/to/another/mqtt", Payload: []byte{}} | ||||
| 	q = index.RetainMessage(pk3) | ||||
| 	require.Equal(t, int64(-1), q) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"]) | ||||
| 	require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message) | ||||
| 	require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain) | ||||
| } | ||||
|  | ||||
| func BenchmarkRetainMessage(b *testing.B) { | ||||
| 	index := New() | ||||
| 	pk := packets.Packet{TopicName: "path/to/another/mqtt"} | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		index.RetainMessage(pk) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestSubscribeOK(t *testing.T) { | ||||
| 	index := New() | ||||
|  | ||||
| 	q := index.Subscribe("path/to/my/mqtt", "client-1", 0) | ||||
| 	require.Equal(t, true, q) | ||||
|  | ||||
| 	q = index.Subscribe("path/to/my/mqtt", "client-1", 0) | ||||
| 	require.Equal(t, false, q) | ||||
|  | ||||
| 	q = index.Subscribe("path/to/my/mqtt", "client-2", 0) | ||||
| 	require.Equal(t, true, q) | ||||
|  | ||||
| 	q = index.Subscribe("path/to/another/mqtt", "client-1", 0) | ||||
| 	require.Equal(t, true, q) | ||||
|  | ||||
| 	q = index.Subscribe("path/+", "client-2", 0) | ||||
| 	require.Equal(t, true, q) | ||||
|  | ||||
| 	q = index.Subscribe("#", "client-3", 0) | ||||
| 	require.Equal(t, true, q) | ||||
|  | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1") | ||||
| 	require.Equal(t, "path/to/my/mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Filter) | ||||
| 	require.Equal(t, "mqtt", index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Key) | ||||
| 	require.Equal(t, index.Root.Leaves["path"], index.Root.Leaves["path"].Leaves["to"].Parent) | ||||
| 	require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-2") | ||||
|  | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1") | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["+"].Clients, "client-2") | ||||
| 	require.Contains(t, index.Root.Leaves["#"].Clients, "client-3") | ||||
| } | ||||
|  | ||||
| func BenchmarkSubscribe(b *testing.B) { | ||||
| 	index := New() | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		index.Subscribe("path/to/mqtt/basic", "client-1", 0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestUnsubscribeA(t *testing.T) { | ||||
| 	index := New() | ||||
| 	index.Subscribe("path/to/my/mqtt", "client-1", 0) | ||||
| 	index.Subscribe("path/to/+/mqtt", "client-1", 0) | ||||
| 	index.Subscribe("path/to/stuff", "client-1", 0) | ||||
| 	index.Subscribe("path/to/stuff", "client-2", 0) | ||||
| 	index.Subscribe("#", "client-3", 0) | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Clients, "client-1") | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1") | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1") | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2") | ||||
| 	require.Contains(t, index.Root.Leaves["#"].Clients, "client-3") | ||||
|  | ||||
| 	ok := index.Unsubscribe("path/to/my/mqtt", "client-1") | ||||
| 	require.Equal(t, true, ok) | ||||
|  | ||||
| 	require.Nil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"]) | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["+"].Leaves["mqtt"].Clients, "client-1") | ||||
|  | ||||
| 	ok = index.Unsubscribe("path/to/stuff", "client-1") | ||||
| 	require.Equal(t, true, ok) | ||||
|  | ||||
| 	require.NotContains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-1") | ||||
| 	require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["stuff"].Clients, "client-2") | ||||
| 	require.Contains(t, index.Root.Leaves["#"].Clients, "client-3") | ||||
|  | ||||
| 	ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "client-1") | ||||
| 	require.Equal(t, false, ok) | ||||
|  | ||||
| } | ||||
|  | ||||
| func TestUnsubscribeCascade(t *testing.T) { | ||||
| 	index := New() | ||||
| 	index.Subscribe("a/b/c", "client-1", 0) | ||||
| 	index.Subscribe("a/b/c/e/e", "client-1", 0) | ||||
|  | ||||
| 	ok := index.Unsubscribe("a/b/c/e/e", "client-1") | ||||
| 	require.Equal(t, true, ok) | ||||
| 	require.NotEmpty(t, index.Root.Leaves) | ||||
| 	require.Contains(t, index.Root.Leaves["a"].Leaves["b"].Leaves["c"].Clients, "client-1") | ||||
| } | ||||
|  | ||||
| // This benchmark is Unsubscribe-Subscribe | ||||
| func BenchmarkUnsubscribe(b *testing.B) { | ||||
| 	index := New() | ||||
|  | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		index.Subscribe("path/to/my/mqtt", "client-1", 0) | ||||
| 		index.Unsubscribe("path/to/mqtt/basic", "client-1") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestSubscribersFind(t *testing.T) { | ||||
| 	tt := []struct { | ||||
| 		filter string | ||||
| 		topic  string | ||||
| 		len    int | ||||
| 	}{ | ||||
| 		{ | ||||
| 			filter: "a", | ||||
| 			topic:  "a", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "a/", | ||||
| 			topic:  "a", | ||||
| 			len:    0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "a/", | ||||
| 			topic:  "a/", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "/a", | ||||
| 			topic:  "/a", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "path/to/my/mqtt", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "path/to/+/mqtt", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "+/to/+/mqtt", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "#", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "+/+/+/+", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "+/+/+/#", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "zen/#", | ||||
| 			topic:  "zen", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "+/+/#", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "path/to/", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "#/stuff", | ||||
| 			topic:  "path/to/my/mqtt", | ||||
| 			len:    0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "$SYS/#", | ||||
| 			topic:  "$SYS/info", | ||||
| 			len:    1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "#", | ||||
| 			topic:  "$SYS/info", | ||||
| 			len:    0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			filter: "+/info", | ||||
| 			topic:  "$SYS/info", | ||||
| 			len:    0, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for i, check := range tt { | ||||
| 		index := New() | ||||
| 		index.Subscribe(check.filter, "client-1", 0) | ||||
| 		clients := index.Subscribers(check.topic) | ||||
| 		//spew.Dump(clients) | ||||
| 		require.Equal(t, check.len, len(clients), "Unexpected clients len at %d %s %s", i, check.filter, check.topic) | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| func BenchmarkSubscribers(b *testing.B) { | ||||
| 	index := New() | ||||
| 	index.Subscribe("path/to/my/mqtt", "client-1", 0) | ||||
| 	index.Subscribe("path/to/+/mqtt", "client-1", 0) | ||||
| 	index.Subscribe("something/things/stuff/+", "client-1", 0) | ||||
| 	index.Subscribe("path/to/stuff", "client-2", 0) | ||||
| 	index.Subscribe("#", "client-3", 0) | ||||
|  | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		index.Subscribers("path/to/testing/mqtt") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIsolateParticle(t *testing.T) { | ||||
| 	particle, hasNext := isolateParticle("path/to/my/mqtt", 0) | ||||
| 	require.Equal(t, "path", particle) | ||||
| 	require.Equal(t, true, hasNext) | ||||
| 	particle, hasNext = isolateParticle("path/to/my/mqtt", 1) | ||||
| 	require.Equal(t, "to", particle) | ||||
| 	require.Equal(t, true, hasNext) | ||||
| 	particle, hasNext = isolateParticle("path/to/my/mqtt", 2) | ||||
| 	require.Equal(t, "my", particle) | ||||
| 	require.Equal(t, true, hasNext) | ||||
| 	particle, hasNext = isolateParticle("path/to/my/mqtt", 3) | ||||
| 	require.Equal(t, "mqtt", particle) | ||||
| 	require.Equal(t, false, hasNext) | ||||
|  | ||||
| 	particle, hasNext = isolateParticle("/path/", 0) | ||||
| 	require.Equal(t, "", particle) | ||||
| 	require.Equal(t, true, hasNext) | ||||
| 	particle, hasNext = isolateParticle("/path/", 1) | ||||
| 	require.Equal(t, "path", particle) | ||||
| 	require.Equal(t, true, hasNext) | ||||
| 	particle, hasNext = isolateParticle("/path/", 2) | ||||
| 	require.Equal(t, "", particle) | ||||
| 	require.Equal(t, false, hasNext) | ||||
|  | ||||
| 	particle, hasNext = isolateParticle("a/b/c/+/+", 3) | ||||
| 	require.Equal(t, "+", particle) | ||||
| 	require.Equal(t, true, hasNext) | ||||
| 	particle, hasNext = isolateParticle("a/b/c/+/+", 4) | ||||
| 	require.Equal(t, "+", particle) | ||||
| 	require.Equal(t, false, hasNext) | ||||
| } | ||||
|  | ||||
| func BenchmarkIsolateParticle(b *testing.B) { | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		isolateParticle("path/to/my/mqtt", 3) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMessagesPattern(t *testing.T) { | ||||
| 	tt := []struct { | ||||
| 		packet packets.Packet | ||||
| 		filter string | ||||
| 		len    int | ||||
| 	}{ | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "a/b/c/d", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"a/b/c/d", | ||||
| 			1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "a/b/c/e", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"a/+/c/+", | ||||
| 			2, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "a/b/d/f", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"+/+/+/+", | ||||
| 			3, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "q/w/e/r/t/y", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"q/w/e/#", | ||||
| 			1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "q/w/x/r/t/x", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"q/#", | ||||
| 			2, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "asd", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"asd", | ||||
| 			1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "$SYS/testing", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"#", | ||||
| 			8, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "$SYS/test", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"+/testing", | ||||
| 			0, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "$SYS/info", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"$SYS/info", | ||||
| 			1, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "$SYS/b", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"$SYS/#", | ||||
| 			4, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "asd/fgh/jkl", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"#", | ||||
| 			8, | ||||
| 		}, | ||||
| 		{ | ||||
| 			packets.Packet{TopicName: "stuff/asdadsa/dsfdsafdsadfsa/dsfdsf/sdsadas", Payload: []byte{'h', 'e', 'l', 'l', 'o'}, FixedHeader: packets.FixedHeader{Retain: true}}, | ||||
| 			"stuff/#/things", // indexer will ignore trailing /things | ||||
| 			1, | ||||
| 		}, | ||||
| 	} | ||||
| 	index := New() | ||||
| 	for _, check := range tt { | ||||
| 		index.RetainMessage(check.packet) | ||||
| 	} | ||||
|  | ||||
| 	for i, check := range tt { | ||||
| 		messages := index.Messages(check.filter) | ||||
| 		require.Equal(t, check.len, len(messages), "Unexpected messages len at %d %s %s", i, check.filter, check.packet.TopicName) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMessagesFind(t *testing.T) { | ||||
| 	index := New() | ||||
| 	index.RetainMessage(packets.Packet{TopicName: "a/a", Payload: []byte{'a'}, FixedHeader: packets.FixedHeader{Retain: true}}) | ||||
| 	index.RetainMessage(packets.Packet{TopicName: "a/b", Payload: []byte{'b'}, FixedHeader: packets.FixedHeader{Retain: true}}) | ||||
| 	messages := index.Messages("a/a") | ||||
| 	require.Equal(t, 1, len(messages)) | ||||
|  | ||||
| 	messages = index.Messages("a/+") | ||||
| 	require.Equal(t, 2, len(messages)) | ||||
| } | ||||
|  | ||||
| func BenchmarkMessages(b *testing.B) { | ||||
| 	index := New() | ||||
| 	index.RetainMessage(packets.Packet{TopicName: "path/to/my/mqtt"}) | ||||
| 	index.RetainMessage(packets.Packet{TopicName: "path/to/another/mqtt"}) | ||||
| 	index.RetainMessage(packets.Packet{TopicName: "path/a/some/mqtt"}) | ||||
| 	index.RetainMessage(packets.Packet{TopicName: "what/is"}) | ||||
| 	index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"}) | ||||
|  | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		index.Messages("path/to/+/mqtt") | ||||
| 	} | ||||
| } | ||||
| @@ -1,12 +0,0 @@ | ||||
| package auth | ||||
|  | ||||
| // Controller is an interface for authentication controllers. | ||||
| type Controller interface { | ||||
|  | ||||
| 	// Authenticate authenticates a user on CONNECT and returns true if a user is | ||||
| 	// allowed to join the server. | ||||
| 	Authenticate(user, password []byte) bool | ||||
|  | ||||
| 	// ACL returns true if a user has read or write access to a given topic. | ||||
| 	ACL(user []byte, topic string, write bool) bool | ||||
| } | ||||
| @@ -1,31 +0,0 @@ | ||||
| package auth | ||||
|  | ||||
| // Allow is an auth controller which allows access to all connections and topics. | ||||
| type Allow struct{} | ||||
|  | ||||
| // Auth returns true if a username and password are acceptable. Allow always | ||||
| // returns true. | ||||
| func (a *Allow) Authenticate(user, password []byte) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // ACL returns true if a user has access permissions to read or write on a topic. | ||||
| // Allow always returns true. | ||||
| func (a *Allow) ACL(user []byte, topic string, write bool) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // Disallow is an auth controller which disallows access to all connections and topics. | ||||
| type Disallow struct{} | ||||
|  | ||||
| // Auth returns true if a username and password are acceptable. Disallow always | ||||
| // returns false. | ||||
| func (d *Disallow) Authenticate(user, password []byte) bool { | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // ACL returns true if a user has access permissions to read or write on a topic. | ||||
| // Disallow always returns false. | ||||
| func (d *Disallow) ACL(user []byte, topic string, write bool) bool { | ||||
| 	return false | ||||
| } | ||||
| @@ -1,55 +0,0 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestAllowAuth(t *testing.T) { | ||||
| 	ac := new(Allow) | ||||
| 	require.Equal(t, true, ac.Authenticate([]byte("user"), []byte("pass"))) | ||||
| } | ||||
|  | ||||
| func BenchmarkAllowAuth(b *testing.B) { | ||||
| 	ac := new(Allow) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		ac.Authenticate([]byte("user"), []byte("pass")) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAllowACL(t *testing.T) { | ||||
| 	ac := new(Allow) | ||||
| 	require.Equal(t, true, ac.ACL([]byte("user"), "topic", true)) | ||||
| } | ||||
|  | ||||
| func BenchmarkAllowACL(b *testing.B) { | ||||
| 	ac := new(Allow) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		ac.ACL([]byte("user"), "pass", true) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDisallowAuth(t *testing.T) { | ||||
| 	ac := new(Disallow) | ||||
| 	require.Equal(t, false, ac.Authenticate([]byte("user"), []byte("pass"))) | ||||
| } | ||||
|  | ||||
| func BenchmarkDisallowAuth(b *testing.B) { | ||||
| 	ac := new(Disallow) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		ac.Authenticate([]byte("user"), []byte("pass")) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDisallowACL(t *testing.T) { | ||||
| 	ac := new(Disallow) | ||||
| 	require.Equal(t, false, ac.ACL([]byte("user"), "topic", true)) | ||||
| } | ||||
|  | ||||
| func BenchmarkDisallowACL(b *testing.B) { | ||||
| 	ac := new(Disallow) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		ac.ACL([]byte("user"), "pass", true) | ||||
| 	} | ||||
| } | ||||
| @@ -1,121 +0,0 @@ | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| 	"github.com/mochi-co/mqtt/server/system" | ||||
| ) | ||||
|  | ||||
| // HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. | ||||
| type HTTPStats struct { | ||||
| 	sync.RWMutex | ||||
| 	id      string       // the internal id of the listener. | ||||
| 	config  *Config      // configuration values for the listener. | ||||
| 	system  *system.Info // pointers to the server data. | ||||
| 	address string       // the network address to bind to. | ||||
| 	listen  *http.Server // the http server. | ||||
| 	end     int64        // ensure the close methods are only called once.} | ||||
| } | ||||
|  | ||||
| // NewHTTPStats initialises and returns a new HTTP listener, listening on an address. | ||||
| func NewHTTPStats(id, address string) *HTTPStats { | ||||
| 	return &HTTPStats{ | ||||
| 		id:      id, | ||||
| 		address: address, | ||||
| 		config: &Config{ | ||||
| 			Auth: new(auth.Allow), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // SetConfig sets the configuration values for the listener config. | ||||
| func (l *HTTPStats) SetConfig(config *Config) { | ||||
| 	l.Lock() | ||||
| 	if config != nil { | ||||
| 		l.config = config | ||||
|  | ||||
| 		// If a config has been passed without an auth controller, | ||||
| 		// it may be a mistake, so disallow all traffic. | ||||
| 		if l.config.Auth == nil { | ||||
| 			l.config.Auth = new(auth.Disallow) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	l.Unlock() | ||||
| } | ||||
|  | ||||
| // ID returns the id of the listener. | ||||
| func (l *HTTPStats) ID() string { | ||||
| 	l.RLock() | ||||
| 	id := l.id | ||||
| 	l.RUnlock() | ||||
| 	return id | ||||
| } | ||||
|  | ||||
| // Listen starts listening on the listener's network address. | ||||
| func (l *HTTPStats) Listen(s *system.Info) error { | ||||
| 	l.system = s | ||||
| 	mux := http.NewServeMux() | ||||
| 	mux.HandleFunc("/", l.jsonHandler) | ||||
| 	l.listen = &http.Server{ | ||||
| 		Addr:    l.address, | ||||
| 		Handler: mux, | ||||
| 	} | ||||
|  | ||||
| 	if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 { | ||||
| 		cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		l.listen.TLSConfig = &tls.Config{ | ||||
| 			Certificates: []tls.Certificate{cert}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Serve starts listening for new connections and serving responses. | ||||
| func (l *HTTPStats) Serve(establish EstablishFunc) { | ||||
| 	if l.listen.TLSConfig != nil { | ||||
| 		l.listen.ListenAndServeTLS("", "") | ||||
| 	} else { | ||||
| 		l.listen.ListenAndServe() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Close closes the listener and any client connections. | ||||
| func (l *HTTPStats) Close(closeClients CloseFunc) { | ||||
| 	l.Lock() | ||||
| 	defer l.Unlock() | ||||
|  | ||||
| 	if atomic.LoadInt64(&l.end) == 0 { | ||||
| 		atomic.StoreInt64(&l.end, 1) | ||||
|  | ||||
| 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 		defer cancel() | ||||
| 		l.listen.Shutdown(ctx) | ||||
| 	} | ||||
|  | ||||
| 	closeClients(l.id) | ||||
| } | ||||
|  | ||||
| // jsonHandler is an HTTP handler which outputs the $SYS stats as JSON. | ||||
| func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) { | ||||
| 	info, err := json.MarshalIndent(l.system, "", "\t") | ||||
| 	if err != nil { | ||||
| 		io.WriteString(w, err.Error()) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	w.Write(info) | ||||
| } | ||||
| @@ -1,186 +0,0 @@ | ||||
| package listeners | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/mochi-co/mqtt/server/listeners/auth" | ||||
| 	"github.com/mochi-co/mqtt/server/system" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| func TestNewHTTPStats(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	require.Equal(t, "t1", l.id) | ||||
| 	require.Equal(t, testPort, l.address) | ||||
| } | ||||
|  | ||||
| func BenchmarkNewHTTPStats(b *testing.B) { | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		NewHTTPStats("t1", testPort) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsSetConfig(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
|  | ||||
| 	l.SetConfig(&Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 	}) | ||||
| 	require.NotNil(t, l.config) | ||||
| 	require.NotNil(t, l.config.Auth) | ||||
| 	require.Equal(t, new(auth.Allow), l.config.Auth) | ||||
|  | ||||
| 	// Switch to disallow on bad config set. | ||||
| 	l.SetConfig(new(Config)) | ||||
| 	require.NotNil(t, l.config) | ||||
| 	require.NotNil(t, l.config.Auth) | ||||
| 	require.Equal(t, new(auth.Disallow), l.config.Auth) | ||||
| } | ||||
|  | ||||
| func BenchmarkHTTPStatsSetConfig(b *testing.B) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		l.SetConfig(new(Config)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsID(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	require.Equal(t, "t1", l.ID()) | ||||
| } | ||||
|  | ||||
| func BenchmarkHTTPStatsID(b *testing.B) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	for n := 0; n < b.N; n++ { | ||||
| 		l.ID() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsListen(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	err := l.Listen(new(system.Info)) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.NotNil(t, l.system) | ||||
| 	require.NotNil(t, l.listen) | ||||
| 	require.Equal(t, testPort, l.listen.Addr) | ||||
|  | ||||
| 	l.listen.Close() | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsListenTLS(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	l.SetConfig(&Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 		TLS: &TLS{ | ||||
| 			Certificate: testCertificate, | ||||
| 			PrivateKey:  testPrivateKey, | ||||
| 		}, | ||||
| 	}) | ||||
| 	err := l.Listen(new(system.Info)) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, l.listen.TLSConfig) | ||||
| 	l.listen.Close() | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsListenTLSInvalid(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	l.SetConfig(&Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 		TLS: &TLS{ | ||||
| 			Certificate: []byte("abcde"), | ||||
| 			PrivateKey:  testPrivateKey, | ||||
| 		}, | ||||
| 	}) | ||||
| 	err := l.Listen(new(system.Info)) | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsServeAndClose(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	err := l.Listen(&system.Info{ | ||||
| 		Version: "test", | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
| 	time.Sleep(time.Millisecond) | ||||
|  | ||||
| 	resp, err := http.Get("http://localhost" + testPort) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, resp) | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	v := new(system.Info) | ||||
| 	err = json.Unmarshal(body, v) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "test", v.Version) | ||||
|  | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
| 	require.Equal(t, true, closed) | ||||
|  | ||||
| 	_, err = http.Get("http://localhost" + testPort) | ||||
| 	require.Error(t, err) | ||||
|  | ||||
| 	<-o | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsServeTLSAndClose(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	l.SetConfig(&Config{ | ||||
| 		Auth: new(auth.Allow), | ||||
| 		TLS: &TLS{ | ||||
| 			Certificate: testCertificate, | ||||
| 			PrivateKey:  testPrivateKey, | ||||
| 		}, | ||||
| 	}) | ||||
| 	err := l.Listen(&system.Info{ | ||||
| 		Version: "test", | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	o := make(chan bool) | ||||
| 	go func(o chan bool) { | ||||
| 		l.Serve(MockEstablisher) | ||||
| 		o <- true | ||||
| 	}(o) | ||||
| 	time.Sleep(time.Millisecond) | ||||
| 	var closed bool | ||||
| 	l.Close(func(id string) { | ||||
| 		closed = true | ||||
| 	}) | ||||
| 	require.Equal(t, true, closed) | ||||
| } | ||||
|  | ||||
| func TestHTTPStatsJSONHandler(t *testing.T) { | ||||
| 	l := NewHTTPStats("t1", testPort) | ||||
| 	err := l.Listen(&system.Info{ | ||||
| 		Version: "test", | ||||
| 	}) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	w := httptest.NewRecorder() | ||||
| 	l.jsonHandler(w, nil) | ||||
| 	resp := w.Result() | ||||
| 	body, _ := ioutil.ReadAll(resp.Body) | ||||
|  | ||||
| 	v := new(system.Info) | ||||
| 	err = json.Unmarshal(body, v) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, "test", v.Version) | ||||
| } | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user