Compare commits

..

132 Commits

Author SHA1 Message Date
JB
3a15cc3add Merge pull request #69 from mochi-co/v1.2.0
V1.2.0
2022-04-10 20:23:43 +01:00
mochi
765f6e7c2e use NewServer instead of New 2022-04-10 20:19:06 +01:00
mochi
e3bdfc1f8e update readme for new events 2022-04-10 20:18:18 +01:00
JB
60cd972b7f Merge pull request #68 from mochi-co/fix-store-retained
Fix Store Retained Messages
2022-04-10 19:55:34 +01:00
mochi
ee081d0abe Merge branch 'v1.2.0' of https://github.com/mochi-co/mqtt into fix-store-retained 2022-04-10 19:54:37 +01:00
JB
d97b4bb81d Merge pull request #67 from mochi-co/release-client-buffers
Release client buffers
2022-04-10 19:53:46 +01:00
mochi
94aeacf0cb only check final outcome due to races 2022-04-10 19:46:12 +01:00
mochi
5cb8a081a1 accept any error for invalid protocol due to races 2022-04-10 19:43:59 +01:00
mochi
fc00112e47 Check for protocol violation errors 2022-04-10 19:42:37 +01:00
mochi
bbc22fae5b Add comments 2022-04-10 19:30:26 +01:00
mochi
82cb75913d Remove unused code 2022-04-10 19:28:22 +01:00
mochi
45c4a64b87 Abandon client state if the existing client specified a cleansession 2022-04-10 19:07:46 +01:00
mochi
bae2579497 Expose CleanSession value for checking 2022-04-10 19:07:15 +01:00
mochi
1bc01271cb Store retained message based on corrected r value 2022-04-10 18:47:45 +01:00
mochi
352a71f50c Expect correct r values for RetainMessage 2022-04-10 18:46:46 +01:00
mochi
6f9f62e38f Correctly return R value of retainMessage 2022-04-10 18:46:23 +01:00
mochi
0f67d9e8ff Update EstablishConnection tests to ensure buffers and pool are correctly released after use 2022-04-10 17:37:46 +01:00
mochi
f2dd5b63ae Export R/W buffer values so they can be assessed in tests without causing races 2022-04-10 17:37:25 +01:00
mochi
54e2d044a2 Track number of pool blocks in use 2022-04-10 17:36:46 +01:00
mochi
6298a87298 Use package errors instead of strings 2022-04-10 14:58:59 +01:00
mochi
b6fd25bba4 test clarbuffers 2022-04-10 01:46:18 +01:00
mochi
eef3592576 clear buffers after deferred stop 2022-04-10 01:46:04 +01:00
mochi
5d343c12e1 refactor clients for buffer releasing 2022-04-10 01:34:01 +01:00
mochi
70f52c8a3b Refactor establishconnection to prevent same-id disconnects 2022-04-10 01:33:35 +01:00
mochi
429b72265a refactor connSetup for clarity 2022-04-10 01:32:24 +01:00
mochi
f60d2dcfca Clarify error messages 2022-04-10 01:32:04 +01:00
mochi
6674cd64eb clarify error checking 2022-04-10 01:31:36 +01:00
mochi
f218cde69c Use defer to release buffers and decrease stats on any client closure 2022-04-09 20:53:28 +01:00
JB
9ea687eb94 Merge pull request #61 from mochi-co/server-options
Configurable Server Options
2022-04-09 20:29:23 +01:00
JB
949e4e2e91 Merge pull request #63 from mochi-co/add-drop-packet-error
Add ErrRejectPacket to OnProcessMessage
2022-04-09 20:29:03 +01:00
JB
515e0269de Merge pull request #62 from mochi-co/fix-inflight-key
Fix Inflight Persistence Key
2022-04-06 11:23:51 +01:00
mochi
ae6073c79c track logged error 2022-03-31 18:23:12 +01:00
mochi
b072a08f0b Update test to check for packet rejection 2022-03-31 18:04:41 +01:00
mochi
01d8a450d2 Ensure OnError is set before using it 2022-03-31 17:56:08 +01:00
mochi
da2fd41f79 Update OnProcessMessage documentation 2022-03-31 17:53:34 +01:00
mochi
56e8039093 Optionally drop a packet if the ErrRejectPacket error is returned from OnProcessMessage 2022-03-31 17:53:13 +01:00
mochi
7b9bc844c1 Add ErrRejectPacket error to abandon packet processing from OnMessageProcess 2022-03-31 17:52:40 +01:00
JB
8acb182820 Merge pull request #53 from stffabi/feature/onprocessmessage-event
Events: Add OnProcessMessage event
2022-03-31 17:41:22 +01:00
mochi
5726880095 fix inflight key reference 2022-03-31 17:36:03 +01:00
mochi
ee459e1b3d Fix code block formatting 2022-03-31 17:32:17 +01:00
mochi
6aec3a8bbf Update readme with server options 2022-03-31 17:21:42 +01:00
mochi
74699f0a87 Add example implementation 2022-03-31 17:21:35 +01:00
mochi
70def39ff9 add tests for new NewServer function 2022-03-31 17:00:00 +01:00
mochi
8e7098a32d remove deprecated log message 2022-03-31 16:59:49 +01:00
mochi
e4f02919fd Update code to use new NewServer function instead of deprecated New 2022-03-31 16:49:27 +01:00
mochi
99c96c844e Update example code to use new NewServer function instead of deprecated New 2022-03-31 16:49:02 +01:00
mochi
18629aea6d Use internal default values instead of relying on passed value 2022-03-31 16:48:42 +01:00
mochi
a0060429d1 Add Server Options
Adds a new struct of server options which can be used to override default properties. A new options-accepting NewServer function has been created to supersede the New method, which is now deprecated.
2022-03-31 16:48:29 +01:00
mochi
d946a9ae16 Update go mod to ensure bolt is using 1.3.5
Bolt 1.3.6 fails to build correctly and has been removed, so rollback to bolt 1.3.5. Also upgrade to Go 1.18
2022-03-31 16:39:28 +01:00
JB
51a2eb5f48 Merge pull request #57 from hybridgroup/v1.2.0-docker
docker: add initial simple Dockerfile
2022-03-30 09:37:53 +01:00
JB
0d4b0a89d8 Merge pull request #58 from soyoo/patch-1
typo
2022-03-30 09:30:11 +01:00
soyoo
0e7ccfe3fb typo 2022-03-25 16:55:57 +08:00
Ron Evans
5d7230630d docker: add initial simple Dockerfile 2022-03-23 10:37:20 +01:00
JB
6a3cbd6093 Merge pull request #51 from jmacd/jmacd/noracefix
Two no-functional-change cleanups combined
2022-03-22 16:19:24 +00:00
stffabi
7b4e79707b Events: Add OnProcessMessage event
This event gets called right after ACL checking but before any other
Fields of the packet get evaluated.
2022-03-21 10:11:29 +01:00
Joshua MacDonald
c6643592f6 Combines two fixes 2022-03-17 14:19:23 -07:00
mochi
5de12d0460 Merge branch 'master' of https://github.com/mochi-co/mqtt into v1.1.2 2022-03-17 12:53:59 +00:00
JB
0a7205e110 Update README.md 2022-03-17 12:50:01 +00:00
JB
8133dd8299 Update README.md 2022-03-17 12:48:37 +00:00
JB
fdbfff57dc Merge pull request #46 from stffabi/bugfix/acls-retain
Subscribe: Only send retained messages if ACLs has allowed subscription to the topic
2022-03-17 09:41:48 +00:00
stffabi
f5fc5e8c44 Subscribe: Only send retained messages if ACLs has allowed subscription to the topic 2022-03-17 09:25:29 +01:00
mochi
9f44712b80 Fix incorrect test
The previous publish inline test incorrectly approved retain packets without retain=true fixedheader values.
2022-03-16 18:16:48 +00:00
stffabi
1f86168d9d Publish: Set the retain flag in the fixedheader (#42)
* Publish: Set the retain flag in the fixedheader
2022-03-16 18:12:07 +00:00
mochi
ab25083ed2 Merge branch 'master' of https://github.com/mochi-co/mqtt into v1.1.2
# Conflicts:
#	server/internal/clients/clients.go
2022-03-16 18:09:00 +00:00
JB
9b0aa4d559 Update README.md 2022-03-15 20:04:04 +00:00
JB
03814944a9 Update README.md 2022-03-15 19:58:58 +00:00
JB
3286d5a484 Replace Travis with Github Actions (#41)
* Remove Travis CI

* Add Github Actions Workflow

* Update badges for build status, coverage, report card, doc reference

* use actions for all pull requests and pushes

* test all files for coverage

* Apply gofmt -s to simplify code

* Fix typos

* Cleanup comments

* Cleanup comments

Co-authored-by: mochi <mochimou@icloud.com>
2022-03-15 19:56:42 +00:00
mochi
7e970d3c7a Fix typo 2022-03-15 19:13:24 +00:00
mochi
d6a92cc5bd Add Keyed fields to events.Client for readability and go vet 2022-03-15 18:44:49 +00:00
mochi
325d44d478 Add missing method comments 2022-03-15 18:44:21 +00:00
Joshua MacDonald
0a5f6d3a9d Add an OnError handler; report the reason for disconnects. (#38) 2022-03-15 17:59:52 +00:00
Joshua MacDonald
17253ad8bd Wrap packet errors with cause information (#39) 2022-03-15 17:34:49 +00:00
Joshua MacDonald
9f1c387091 Move two WaitGroup.Add calls (#36) 2022-03-15 17:33:31 +00:00
JB
9c6f602630 Merge pull request #29 from jmacd/jmacd/payload_not_utf8
Support non-UTF8 payloads (per MQTT specification)
2022-02-27 08:36:40 +00:00
Joshua MacDonald
b0dcaabdde Support non-UTF8 payloads per MQTT specification 2022-02-26 22:53:51 -08:00
JB
460f0ef681 revert redis update 2022-02-24 21:19:46 +00:00
JB
6e16765f60 revert server version 2022-02-24 21:19:22 +00:00
JB
2b361df19e Merge pull request #27 from mochi-co/revert-26-master
Revert "added redis persistence mode"
2022-02-24 21:15:08 +00:00
JB
c8c0a5a094 Revert "added redis persistence mode" 2022-02-24 21:10:39 +00:00
JB
4a833dd081 Update server version 2022-02-24 21:07:54 +00:00
JB
81198d9845 Update README.md 2022-02-24 21:07:24 +00:00
JB
6c12d8a71a Merge pull request #26 from wind-c/master
added redis persistence mode
2022-02-24 21:05:35 +00:00
narwal
19b598b672 redis and trie 2022-02-23 16:00:32 +08:00
narwal
b6529f05d3 add redis persistence mode and example 2022-02-22 18:57:31 +08:00
mochi
7f76445cc8 update server version 2022-01-30 10:39:49 +00:00
JB
b1c01792cd Merge pull request #24 from mochi-co/feature/optimise-struct-fields
Optimise Struct Fields + Fixes
2022-01-30 10:38:28 +00:00
mochi
eda03d4338 optimise Server struct 2022-01-30 10:30:34 +00:00
mochi
18070f1f57 pass byte pool by address 2022-01-30 10:30:19 +00:00
mochi
7f10c28a37 remove println 2022-01-30 10:30:01 +00:00
mochi
122531bb27 Pass inflight by address to avoid lock copying 2022-01-28 21:07:10 +00:00
mochi
e6dbcae428 Correct function signature 2022-01-28 21:06:57 +00:00
mochi
98875de568 Update test to match new FixedHeader struct 2022-01-28 21:06:43 +00:00
mochi
c9fd9451af Prevent locks from being copied 2022-01-28 21:06:24 +00:00
mochi
6550b8d680 8bit align struct fields 2022-01-28 21:05:50 +00:00
mochi
a60c96c889 Update comment for clarity 2022-01-28 21:04:15 +00:00
mochi
86e0a5827e Update version to 1.1.0 2022-01-26 20:49:53 +00:00
mochi
06c399b606 indicate ARM32 compatibility 2022-01-26 20:49:42 +00:00
JB
ed117f67a1 Merge pull request #22 from mochi-co/feature/32bit-compatibility
ARM32 Compatibility
2022-01-26 20:36:13 +00:00
JB
880a3299e1 Merge pull request #19 from rkennedy/bugfix/32-bit-atomic-alignment
Improve 32-bit compatibility
2022-01-26 08:02:57 +00:00
Rob Kennedy
1c408d05be Fix encodeLength for 32-bit platforms
When `int` is 32 bits, `MaxInt64` doesn't fit. It's apparent that
`encodeLength` expects to handle 64-bit inputs, so let's make that
explicit, which allows the test to run on all platforms.
2022-01-25 00:22:26 -06:00
Rob Kennedy
fce495f83e Avoid race condition when closing listeners
"Atomic load" followed by "atomic store" is not itself an atomic
operation. This commit replaces that sequence with CompareAndSwap
instead.
2022-01-25 00:22:26 -06:00
Rob Kennedy
471ca00a64 Make atomics work on 32-bit systems
On 32-bit systems, `atomic` requires its 64-bit arguments to have 64-bit
alignment, but the compiler doesn't help ensure that's the case. In this
commit, fields that don't need to hold large numbers have been converted
to 32-bit types, which are always aligned correctly on all platforms.
For fields that may hold large numeric values, padding has been added to
get the necessary alignment, and tests have been added to avoid
regressions.
2022-01-25 00:22:26 -06:00
mochi
a2c0749640 Update server version to 1.0.5 2022-01-24 18:46:34 +00:00
JB
37293aeecf Merge pull request #18 from mochi-co/feature/connect-disconnect-hooks
OnConnect and OnDisconnect Event Hooks
2022-01-24 18:44:39 +00:00
mochi
7a2d4db6a4 Update for OnConnect and OnDisconnect hooks 2022-01-24 18:42:09 +00:00
mochi
03d2a8bc82 Add tests for OnConnect, OnDisconnect 2022-01-24 18:29:18 +00:00
mochi
4b51e5c7d1 Add OnConnect and OnDisconnect hooks to example 2022-01-24 17:42:33 +00:00
mochi
d15ad682bf Call OnDisconnect Event if applicable 2022-01-24 17:42:19 +00:00
mochi
130ffcbb53 Add OnDisconnect Event Hook 2022-01-24 17:42:04 +00:00
mochi
33cf2f991b Add testbolt file to ignore list 2022-01-24 17:41:46 +00:00
mochi
a360ea6a6c Call OnConnect Event if applicable 2022-01-24 17:37:11 +00:00
mochi
ae3aa0d3fa Add OnConnect event hook 2022-01-24 17:36:50 +00:00
mochi
811ae0e1be Prevent locks being copied by passing non-pointer to FromClient 2022-01-24 17:36:14 +00:00
JB
51d6825430 Merge pull request #15 from ClarkQAQ/master
Fixed some bugs, wish the project better and better
2022-01-17 10:08:20 +00:00
clark
514288c53e update tcp.go maybe this will be better 2022-01-16 20:06:49 +08:00
clark
957fc0a049 fix local variable black hole 2022-01-16 18:23:45 +08:00
clark
03f94f948a update mock.go plase use range 2022-01-16 18:22:37 +08:00
clark
1bc752a2b8 fix [ST1005] strings should not be capitalized 2022-01-16 18:21:33 +08:00
clark
b9db59ba12 update websocket.go fix check origin 2022-01-16 18:20:06 +08:00
JB
c0ef58c363 Update README.md 2022-01-14 17:48:21 +00:00
JB
994adea3b4 Merge pull request #14 from mochi-co/feature/allow-clients-value
Add AllowClients Field to packets
2022-01-14 17:38:29 +00:00
mochi
fc61cc9be5 Add example for AllowClients field 2022-01-14 17:04:55 +00:00
mochi
22d7338878 Add test for AllowClients field 2022-01-14 17:04:39 +00:00
mochi
3f28515706 Remove unnecessary type declarations 2022-01-14 17:04:21 +00:00
mochi
7d73ce9caf Add setupServerClients to inherit existing server instance
previously new clients generated a new server object, so system stats were not shared. This change ensures all test clients use the same server
2022-01-14 17:04:01 +00:00
mochi
0758bc961c Add AllowClients check in publishToSubscribers
If AllowClients has been set on a packet, ensure only clients in the slice are sent the message
2022-01-14 17:02:31 +00:00
mochi
8472b9ae8a use .systemInfo instead of .system for clarity 2022-01-14 17:01:42 +00:00
mochi
530a018e80 use .systemInfo instead of .system for clarity 2022-01-14 17:01:31 +00:00
mochi
0b594afb4e Add AllowClients field to packets
AllowClients field can be specified during onMessage event to selectively deliver messages
2022-01-14 16:59:17 +00:00
mochi
9d0ea957bb Increment server version 2022-01-14 16:58:48 +00:00
mochi
8067785ac4 Add tests for InSliceString 2022-01-14 16:58:33 +00:00
mochi
6ffc8a8388 Add InSliceString function
Check if a slice of strings contains a string (until slices package available)
2022-01-14 16:58:21 +00:00
88 changed files with 2720 additions and 1244 deletions

43
.github/workflows/build.yml vendored Normal file
View File

@@ -0,0 +1,43 @@
name: build
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.17
- name: Vet
run: go vet ./...
- name: Test
run: go test -v -race ./...
coverage:
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.17.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov

3
.gitignore vendored
View File

@@ -1,2 +1,3 @@
cmd/mqtt
.DS_Store
.DS_Store
server/persistence/bolt/testbolt.db

View File

@@ -1,18 +0,0 @@
dist: xenial
language: go
env:
- GO111MODULE=on
go:
- 1.13.x
git:
depth: 1
script:
- go test -v -race ./... -coverprofile=coverage.txt -covermode=atomic
after_success:
- bash <(curl -s https://codecov.io/bash)

31
Dockerfile Normal file
View File

@@ -0,0 +1,31 @@
FROM golang:1.18.0-alpine3.15 AS builder
RUN apk update
RUN apk add git
WORKDIR /app
COPY go.mod ./
COPY go.sum ./
RUN go mod download
COPY . ./
RUN go build -o /app/mochi ./cmd
FROM alpine
WORKDIR /
COPY --from=builder /app/mochi .
# tcp
EXPOSE 1883
# websockets
EXPOSE 1882
# dashboard
EXPOSE 8080
ENTRYPOINT [ "/mochi" ]

View File

@@ -1,10 +1,11 @@
<p align="center">
[![Build Status](https://travis-ci.com/mochi-co/mqtt.svg?token=59nqixhtefy2iQRwsPcu&branch=master)](https://travis-ci.com/mochi-co/mqtt)
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
[![codecov](https://codecov.io/gh/mochi-co/mqtt/branch/master/graph/badge.svg?token=6vBUgYVaVB)](https://codecov.io/gh/mochi-co/mqtt)
[![GoDoc](https://godoc.org/github.com/mochi-co/mqtt?status.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt)
</p>
@@ -13,6 +14,10 @@
Mochi MQTT is an embeddable high-performance MQTT broker server written in Go, and compliant with the MQTT v3.0 and v3.1.1 specification for the development of IoT and smarthome projects. The server can be used either as a standalone binary or embedded as a library in your own projects. Mochi MQTT message throughput is comparable with everyone's favourites such as Mosquitto, Mosca, and VerneMQ.
> #### 📦 💬 See Github Discussions for discussions about releases
> Ongoing discussion about current and future releases can be found at https://github.com/mochi-co/mqtt/discussions
#### What is MQTT?
MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks. [Learn more](https://mqtt.org/faq)
@@ -23,11 +28,13 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
- Ring Buffer packet codec.
- TCP, Websocket, (including SSL/TLS) and Dashboard listeners.
- Interfaces for Client Authentication and Topic access control.
- Bolt-backed persistence and storage interfaces.
- Bolt persistence and storage interfaces (see examples folder).
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
- Basic Event Hooks (currently `onMessage`)
- Basic Event Hooks (`OnMessage`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
- ARM32 Compatible.
#### Roadmap
- Please open an issue to request new features or event hooks.
- MQTT v5 compatibility?
#### Using the Broker
@@ -47,7 +54,7 @@ import (
func main() {
// Create the new MQTT Server.
server := mqtt.New()
server := mqtt.NewServer(nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", ":1883")
@@ -105,11 +112,42 @@ err := server.AddListener(tcp, &listeners.Config{
#### Event Hooks
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
##### OnConnect
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
```go
import "github.com/mochi-co/mqtt/server/events"
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
```
##### OnDisconnect
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
```go
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
}
```
##### OnMessage
`server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns.
`server.Events.OnMessage` is called when a Publish packet (message) is received. The method receives the published message and information about the client who published it.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
##### OnProcessMessage
`server.Events.OnProcessMessage` is called before a publish packet (message) is processed. Specifically, the method callback is triggered after topic and ACL validation has occurred, but before the headers and payload are processed. You can use this if you want to programmatically change the data of the packet, such as setting it to retain, or altering the QoS flag.
If an error is returned, the packet will not be modified. and the existing packet will be used. If this is an unwanted outcome, the `mqtt.ErrRejectPacket` error can be returned from the callback, and the packet will be dropped/ignored, any further processing is abandoned.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
```go
import "github.com/mochi-co/mqtt/server/events"
@@ -124,7 +162,34 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
}
```
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
##### OnError
`server.Events.OnError` is called when an error is encountered on the server, particularly within the use of a client connection status.
##### OnStorage
`server.Events.OnStorage` is like `onError`, but receives the output of persistent storage methods.
#### Server Options
A few options can be passed to the `mqtt.NewServer(opts *Options)` function in order to override the default broker configuration. Currently these options are:
- BufferSize (default 1024 * 256 bytes) - The default value is sufficient for most messaging sizes, but if you are sending many kilobytes of data (such as images), you should increase this to a value of (n*s) where is the typical size of your message and n is the number of messages you may have backlogged for a client at any given time.
- BufferBlockSize (default 1024 * 8) - The minimum size in which R/W data will be allocated. If you are expecting only tiny or large payloads, you can alter this accordingly.
Any options which is not set or is `0` will use default values.
```go
opts := &mqtt.Options{
BufferSize: 512 * 1024,
BufferBlockSize: 16 * 1024,
}
s := mqtt.NewServer(opts)
```
> See `examples/tcp/main.go` for an example implementation.
#### Direct Publishing
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.
@@ -154,7 +219,7 @@ if err != nil {
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the test with `python3 client_test.py` from the _interoperability_ folder.
#### Performance (messages/second)
#### Performance at v1.0.0
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a 13-inch, Early 2015 Macbook Pro (2.7 GHz Intel Core i5). Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better. SEND = Publish throughput, RECV = Subscribe throughput.
> As usual, any performance benchmarks should be taken with a pinch of salt, but are shown to demonstrate typical throughput compared to the other leading MQTT brokers.

View File

@@ -33,7 +33,7 @@ func main() {
fmt.Println(aurora.Cyan("Websocket"), *wsAddr)
fmt.Println(aurora.Cyan("$SYS Dashboard"), *infoAddr)
server := mqtt.New()
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", *tcpAddr)
err := server.AddListener(tcp, nil)
if err != nil {

View File

@@ -24,7 +24,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
server := mqtt.NewServer(nil)
stats := listeners.NewHTTPStats("stats", ":8080")
err := server.AddListener(stats, nil)

View File

@@ -27,7 +27,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
@@ -44,6 +44,16 @@ func main() {
}
}()
// Add OnConnect Event Hook
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
// Add OnDisconnect Event Hook
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
}
// Add OnMessage Event Hook
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
pkx = pk
@@ -54,6 +64,12 @@ func main() {
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
}
// Example of using AllowClients to selectively deliver/drop messages.
// Only a client with the id of `allowed-client` will received messages on the topic.
if pkx.TopicName == "a/b/restricted" {
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
}
return pkx, nil
}

View File

@@ -46,7 +46,7 @@ func main() {
// Auth is an example auth provider for the server.
type Auth struct{}
// Auth returns true if a username and password are acceptable.
// Authenticate returns true if a username and password are acceptable.
// Auth always returns true.
func (a *Auth) Authenticate(user, password []byte) bool {
return true
@@ -55,8 +55,5 @@ func (a *Auth) Authenticate(user, password []byte) bool {
// ACL returns true if a user has access permissions to read or write on a topic.
// ACL is used to deny access to a specific topic to satisfy Test.test_subscribe_failure.
func (a *Auth) ACL(user []byte, topic string, write bool) bool {
if topic == "test/nosubscribe" {
return false
}
return true
return topic != "test/nosubscribe"
}

View File

@@ -28,7 +28,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("Persistence"))
server := mqtt.New()
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),

View File

@@ -25,7 +25,13 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
// An example of configuring various server options...
options := &mqtt.Options{
BufferSize: 0, // Use default values
BufferBlockSize: 0, // Use default values
}
server := mqtt.NewServer(options)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),

View File

@@ -57,7 +57,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
server := mqtt.New()
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),

View File

@@ -24,7 +24,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
server := mqtt.NewServer(nil)
ws := listeners.NewWebsocket("ws1", ":1882")
err := server.AddListener(ws, nil)
if err != nil {

12
go.mod
View File

@@ -1,16 +1,16 @@
module github.com/mochi-co/mqtt
go 1.17
go 1.18
require (
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/gorilla/websocket v1.4.2
github.com/jinzhu/copier v0.3.4
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
github.com/logrusorgru/aurora v2.0.3+incompatible
github.com/rs/xid v1.3.0
github.com/stretchr/testify v1.7.0
go.etcd.io/bbolt v1.3.6
github.com/rs/xid v1.4.0
github.com/stretchr/testify v1.7.1
go.etcd.io/bbolt v1.3.5
)
require (

20
go.sum
View File

@@ -14,10 +14,10 @@ github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/copier v0.3.4 h1:mfU6jI9PtCeUjkjQ322dlff9ELjGDu975C2p/nrubVI=
github.com/jinzhu/copier v0.3.4/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
@@ -27,17 +27,17 @@ github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczG
github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU=
go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0 h1:QPlSTtPE2k6PZPasQUbzuK3p9JbS+vMXYVto8g/yrsg=

View File

@@ -1,29 +1,49 @@
package events
import (
"github.com/mochi-co/mqtt/server/internal/clients"
"github.com/mochi-co/mqtt/server/internal/packets"
)
// Events provides callback handlers for different event hooks.
type Events struct {
OnMessage // published message receieved.
OnProcessMessage // published message receieved before evaluation.
OnMessage // published message receieved.
OnError // server error.
OnConnect // client connected.
OnDisconnect // client disconnected.
}
// Packets is an alias for packets.Packet.
type Packet packets.Packet
// Client contains limited information about a connected client.
type Client struct {
ID string
Remote string
Listener string
}
// FromClient returns an event client from a client.
func FromClient(cl clients.Client) Client {
return Client{
ID: cl.ID,
Listener: cl.Listener,
}
// Clientlike is an interface for Clients and client-like objects that
// are able to describe their client/listener IDs and remote address.
type Clientlike interface {
Info() Client
}
// OnProcessMessage is called when a publish message is received, allowing modification
// of the packet data after ACL checking has occurred but before any data is evaluated
// for processing - e.g. for changing the Retain flag. Note, this hook is ONLY called
// by connected client publishers, it is not triggered when using the direct
// s.Publish method. The function receives the sent message and the
// data of the client who published it, and allows the packet to be modified
// before it is dispatched to subscribers. If no modification is required, return
// the original packet data. If an error occurs, the original packet will
// be dispatched as if the event hook had not been triggered.
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
// processing of the packet.
type OnProcessMessage func(Client, Packet) (Packet, error)
// OnMessage function is called when a publish message is received. Note,
// this hook is ONLY called by connected client publishers, it is not triggered when
// using the direct s.Publish method. The function receives the sent message and the
@@ -34,3 +54,15 @@ func FromClient(cl clients.Client) Client {
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
type OnMessage func(Client, Packet) (Packet, error)
// OnConnect is called when a client successfully connects to the broker.
type OnConnect func(Client, Packet)
// OnDisconnect is called when a client disconnects to the broker. An error value
// is passed to the function if the client disconnected abnormally, otherwise it
// will be nil on a normal disconnect.
type OnDisconnect func(Client, error)
// OnError is called when errors that will not be passed to
// OnDisconnect are handled by the server.
type OnError func(Client, error)

View File

@@ -8,33 +8,39 @@ import (
)
var (
DefaultBufferSize int = 1024 * 256 // the default size of the buffer in bytes.
DefaultBlockSize int = 1024 * 8 // the default size per R/W block in bytes.
// DefaultBufferSize is the default size of the buffer in bytes.
DefaultBufferSize int = 1024 * 256
ErrOutOfRange = errors.New("Indexes out of range")
// DefaultBlockSize is the default size per R/W block in bytes.
DefaultBlockSize int = 1024 * 8
// ErrOutOfRange indicates that the index was out of range.
ErrOutOfRange = errors.New("Indexes out of range")
// ErrInsufficientBytes indicates that there were not enough bytes to return.
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
)
// buffer contains core values and methods to be included in a reader or writer.
// Buffer is a circular buffer for reading and writing messages.
type Buffer struct {
Mu sync.RWMutex // the buffer needs it's own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
buf []byte // the bytes buffer.
tmp []byte // a temporary buffer.
Mu sync.RWMutex // the buffer needs its own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
head int64 // the current position in the sequence - a forever increasing index.
tail int64 // the committed position in the sequence - a forever increasing index.
rcond *sync.Cond // the sync condition for the buffer reader.
wcond *sync.Cond // the sync condition for the buffer writer.
done int64 // indicates that the buffer is closed.
State int64 // indicates whether the buffer is reading from (1) or writing to (2).
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
done uint32 // indicates that the buffer is closed.
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
}
// NewBuffer returns a new instance of buffer. You should call NewReader or
// NewWriter instead of this function.
func NewBuffer(size, block int) Buffer {
func NewBuffer(size, block int) *Buffer {
if size == 0 {
size = DefaultBufferSize
}
@@ -47,7 +53,7 @@ func NewBuffer(size, block int) Buffer {
size = 2 * block
}
return Buffer{
return &Buffer{
size: size,
mask: size - 1,
block: block,
@@ -59,14 +65,14 @@ func NewBuffer(size, block int) Buffer {
// NewBufferFromSlice returns a new instance of buffer using a
// pre-existing byte slice.
func NewBufferFromSlice(block int, buf []byte) Buffer {
func NewBufferFromSlice(block int, buf []byte) *Buffer {
l := len(buf)
if block == 0 {
block = DefaultBlockSize
}
b := Buffer{
b := &Buffer{
size: l,
mask: l - 1,
block: block,
@@ -78,7 +84,7 @@ func NewBufferFromSlice(block int, buf []byte) Buffer {
return b
}
// Get will return the tail and head positions of the buffer.
// GetPos will return the tail and head positions of the buffer.
// This method is for use with testing.
func (b *Buffer) GetPos() (int64, int64) {
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
@@ -127,7 +133,7 @@ func (b *Buffer) awaitEmpty(n int) error {
// then wait until tail has moved.
b.rcond.L.Lock()
for !b.checkEmpty(n) {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
b.rcond.L.Unlock()
return io.EOF
}
@@ -146,7 +152,7 @@ func (b *Buffer) awaitFilled(n int) error {
// the forever-incrementing tail and head integers.
b.wcond.L.Lock()
for !b.checkFilled(n) {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
b.wcond.L.Unlock()
return io.EOF
}
@@ -196,7 +202,7 @@ func (b *Buffer) CapDelta() int {
// Stop signals the buffer to stop processing.
func (b *Buffer) Stop() {
atomic.StoreInt64(&b.done, 1)
atomic.StoreUint32(&b.done, 1)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()

View File

@@ -5,6 +5,7 @@ import (
"sync/atomic"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/require"
)
@@ -50,6 +51,18 @@ func TestNewBufferFromSlice0Size(t *testing.T) {
require.Equal(t, 256, cap(buf.buf))
}
func TestAtomicAlignment(t *testing.T) {
var b Buffer
offset := unsafe.Offsetof(b.head)
require.Equalf(t, uintptr(0), offset%8,
"head requires 64-bit alignment for atomic: offset %d", offset)
offset = unsafe.Offsetof(b.tail)
require.Equalf(t, uintptr(0), offset%8,
"tail requires 64-bit alignment for atomic: offset %d", offset)
}
func TestGetPos(t *testing.T) {
buf := NewBuffer(16, 4)
tail, head := buf.GetPos()
@@ -140,7 +153,7 @@ func TestAwaitFilledEnded(t *testing.T) {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
@@ -190,7 +203,7 @@ func TestAwaitEmptyEnded(t *testing.T) {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
@@ -280,7 +293,7 @@ func TestCommitTailEnded(t *testing.T) {
o <- buf.CommitTail(5)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
@@ -300,5 +313,5 @@ func TestCapDelta(t *testing.T) {
func TestStop(t *testing.T) {
buf := NewBuffer(16, 4)
buf.Stop()
require.Equal(t, int64(1), buf.done)
require.Equal(t, uint32(1), buf.done)
}

View File

@@ -2,17 +2,23 @@ package circ
import (
"sync"
"sync/atomic"
)
// BytesPool is a pool of []byte.
type BytesPool struct {
pool sync.Pool
pool *sync.Pool
used int64
}
// NewBytesPool returns a sync.pool of []byte.
func NewBytesPool(n int) BytesPool {
return BytesPool{
pool: sync.Pool{
func NewBytesPool(n int) *BytesPool {
if n == 0 {
n = DefaultBufferSize
}
return &BytesPool{
pool: &sync.Pool{
New: func() interface{} {
return make([]byte, n)
},
@@ -22,6 +28,7 @@ func NewBytesPool(n int) BytesPool {
// Get returns a pooled bytes.Buffer.
func (b *BytesPool) Get() []byte {
atomic.AddInt64(&b.used, 1)
return b.pool.Get().([]byte)
}
@@ -31,4 +38,10 @@ func (b *BytesPool) Put(x []byte) {
x[i] = 0
}
b.pool.Put(x)
atomic.AddInt64(&b.used, -1)
}
// InUse returns the number of pool blocks in use.
func (b *BytesPool) InUse() int64 {
return atomic.LoadInt64(&b.used)
}

View File

@@ -22,6 +22,7 @@ func TestNewBytesPoolGet(t *testing.T) {
buf := bpool.Get()
require.Equal(t, make([]byte, 256), buf)
require.Equal(t, int64(1), bpool.InUse())
}
func BenchmarkBytesPoolGet(b *testing.B) {
@@ -34,7 +35,9 @@ func BenchmarkBytesPoolGet(b *testing.B) {
func TestNewBytesPoolPut(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
require.Equal(t, int64(1), bpool.InUse())
bpool.Put(buf)
require.Equal(t, int64(0), bpool.InUse())
}
func BenchmarkBytesPoolPut(b *testing.B) {

View File

@@ -7,7 +7,7 @@ import (
// Reader is a circular buffer for reading data from an io.Reader.
type Reader struct {
Buffer
*Buffer
}
// NewReader returns a new Circular Reader.
@@ -19,7 +19,7 @@ func NewReader(size, block int) *Reader {
}
}
// NewReaderFromSlice returns a new Circular Reader using a pre-exising
// NewReaderFromSlice returns a new Circular Reader using a pre-existing
// byte slice.
func NewReaderFromSlice(block int, p []byte) *Reader {
b := NewBufferFromSlice(block, p)
@@ -32,10 +32,10 @@ func NewReaderFromSlice(block int, p []byte) *Reader {
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
atomic.StoreInt64(&b.State, 1)
defer atomic.StoreInt64(&b.State, 0)
atomic.StoreUint32(&b.State, 1)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
return total, nil
}
@@ -60,7 +60,7 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
n, err := r.Read(b.buf[start:end])
total += int64(n) // incr total bytes read.
if err != nil {
return total, nil
return total, err
}
// Move the head forward however many bytes were read.

View File

@@ -2,6 +2,8 @@ package circ
import (
"bytes"
"errors"
"io"
"sync/atomic"
"testing"
"time"
@@ -34,16 +36,18 @@ func TestReadFrom(t *testing.T) {
br := bytes.NewReader(b4)
_, err := buf.ReadFrom(br)
require.NoError(t, err)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buf[:4])
require.Equal(t, int64(4), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, int64(8), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, int64(12), buf.head)
}
@@ -60,7 +64,7 @@ func TestReadFromWrap(t *testing.T) {
}()
time.Sleep(time.Millisecond * 100)
go func() {
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
@@ -116,7 +120,7 @@ func TestReadEnded(t *testing.T) {
o <- err
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()

View File

@@ -1,14 +1,14 @@
package circ
import (
"fmt"
"io"
"log"
"sync/atomic"
)
// Writer is a circular buffer for writing data to an io.Writer.
type Writer struct {
Buffer
*Buffer
}
// NewWriter returns a pointer to a new Circular Writer.
@@ -20,7 +20,7 @@ func NewWriter(size, block int) *Writer {
}
}
// NewWriterFromSlice returns a new Circular Writer using a pre-exising
// NewWriterFromSlice returns a new Circular Writer using a pre-existing
// byte slice.
func NewWriterFromSlice(block int, p []byte) *Writer {
b := NewBufferFromSlice(block, p)
@@ -31,11 +31,11 @@ func NewWriterFromSlice(block int, p []byte) *Writer {
}
// WriteTo writes the contents of the buffer to an io.Writer.
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
atomic.StoreInt64(&b.State, 2)
defer atomic.StoreInt64(&b.State, 0)
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
atomic.StoreUint32(&b.State, 2)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 {
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
return total, io.EOF
}
@@ -59,14 +59,12 @@ func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
p = append(p, b.buf[rTail:rHead]...)
}
//fmt.Println("writing", p)
n, err = w.Write(p)
total += n
total += int64(n)
if err != nil {
fmt.Println("writing err", err)
log.Println("error writing to buffer io.Writer;", err)
return
}
//fmt.Println("written", n)
// Move the tail forward the bytes written and broadcast change.
atomic.StoreInt64(&b.tail, tail+int64(n))

View File

@@ -52,14 +52,14 @@ func TestWriteTo(t *testing.T) {
var b bytes.Buffer
w := bufio.NewWriter(&b)
nc := make(chan int)
nc := make(chan int64)
go func() {
n, _ := buf.WriteTo(w)
nc <- n
}()
time.Sleep(time.Millisecond * 100)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()

View File

@@ -12,6 +12,7 @@ import (
"github.com/rs/xid"
"github.com/mochi-co/mqtt/server/events"
"github.com/mochi-co/mqtt/server/internal/circ"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
@@ -23,7 +24,9 @@ var (
// defaultKeepalive is the default connection keepalive value in seconds.
defaultKeepalive uint16 = 10
ErrConnectionClosed = errors.New("Connection not open")
// ErrConnectionClosed is returned when operating on a closed
// connection and/or when no error cause has been given.
ErrConnectionClosed = errors.New("connection not open")
)
// Clients contains a map of the clients known by the broker.
@@ -74,7 +77,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
clients := make([]*Client, 0, cl.Len())
cl.RLock()
for _, v := range cl.internal {
if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 {
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
clients = append(clients, v)
}
}
@@ -84,42 +87,43 @@ func (cl *Clients) GetByListener(id string) []*Client {
// Client contains information about a client known by the broker.
type Client struct {
sync.RWMutex
conn net.Conn // the net.Conn used to establish the connection.
r *circ.Reader // a reader for reading incoming bytes.
w *circ.Writer // a writer for writing outgoing bytes.
ID string // the client id.
AC auth.Controller // an auth controller inherited from the listener.
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
Listener string // the id of the listener the client is connected to.
Inflight Inflight // a map of in-flight qos messages.
Username []byte // the username the client authenticated with.
keepalive uint16 // the number of seconds the connection can wait.
cleanSession bool // indicates if the client expects a clean-session.
packetID uint32 // the current highest packetID.
LWT LWT // the last will and testament for the client.
State State // the operational state of the client.
system *system.Info // pointers to server system info.
LWT LWT // the last will and testament for the client.
Inflight *Inflight // a map of in-flight qos messages.
sync.RWMutex // mutex
Username []byte // the username the client authenticated with.
AC auth.Controller // an auth controller inherited from the listener.
Listener string // the id of the listener the client is connected to.
ID string // the client id.
conn net.Conn // the net.Conn used to establish the connection.
R *circ.Reader // a reader for reading incoming bytes.
W *circ.Writer // a writer for writing outgoing bytes.
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
systemInfo *system.Info // pointers to server system info.
packetID uint32 // the current highest packetID.
keepalive uint16 // the number of seconds the connection can wait.
CleanSession bool // indicates if the client expects a clean-session.
}
// State tracks the state of the client.
type State struct {
Done int64 // atomic counter which indicates that the client has closed.
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
endOnce sync.Once // only end once.
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
Done uint32 // atomic counter which indicates that the client has closed.
endOnce sync.Once // only end once.
stopCause atomic.Value // reason for stopping.
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{
conn: c,
r: r,
w: w,
system: s,
keepalive: defaultKeepalive,
Inflight: Inflight{
conn: c,
R: r,
W: w,
systemInfo: s,
keepalive: defaultKeepalive,
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
@@ -139,7 +143,7 @@ func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Clie
// method is typically called by the persistence restoration system.
func NewClientStub(s *system.Info) *Client {
return &Client{
Inflight: Inflight{
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
@@ -159,11 +163,11 @@ func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
cl.ID = xid.New().String()
}
cl.r.ID = cl.ID + " READER"
cl.w.ID = cl.ID + " WRITER"
cl.R.ID = cl.ID + " READER"
cl.W.ID = cl.ID + " WRITER"
cl.Username = pk.Username
cl.cleanSession = pk.CleanSession
cl.CleanSession = pk.CleanSession
cl.keepalive = pk.Keepalive
if pk.WillFlag {
@@ -185,7 +189,20 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second)
}
cl.conn.SetDeadline(expiry)
_ = cl.conn.SetDeadline(expiry)
}
}
// Info returns an event-version of a client, containing minimal information.
func (cl *Client) Info() events.Client {
addr := "unknown"
if cl.conn != nil && cl.conn.RemoteAddr() != nil {
addr = cl.conn.RemoteAddr().String()
}
return events.Client{
ID: cl.ID,
Remote: addr,
Listener: cl.Listener,
}
}
@@ -218,47 +235,75 @@ func (cl *Client) ForgetSubscription(filter string) {
// Start begins the client goroutines reading and writing packets.
func (cl *Client) Start() {
cl.State.started.Add(2)
go func() {
cl.State.started.Done()
cl.w.WriteTo(cl.conn)
cl.State.endedW.Done()
cl.Stop()
}()
cl.State.endedW.Add(1)
cl.State.endedR.Add(1)
go func() {
cl.State.started.Done()
cl.r.ReadFrom(cl.conn)
cl.State.endedR.Done()
cl.Stop()
_, err := cl.W.WriteTo(cl.conn)
if err != nil {
err = fmt.Errorf("writer: %w", err)
}
cl.State.endedW.Done()
cl.Stop(err)
}()
go func() {
cl.State.started.Done()
_, err := cl.R.ReadFrom(cl.conn)
if err != nil {
err = fmt.Errorf("reader: %w", err)
}
cl.State.endedR.Done()
cl.Stop(err)
}()
cl.State.endedR.Add(1)
cl.State.started.Wait()
}
// ClearBuffers sets the read/write buffers to nil so they can be
// deallocated automatically when no longer in use.
func (cl *Client) ClearBuffers() {
cl.R = nil
cl.W = nil
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop() {
if atomic.LoadInt64(&cl.State.Done) == 1 {
// A cause error may be passed to identfy the reason for stopping.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return
}
cl.State.endOnce.Do(func() {
cl.r.Stop()
cl.w.Stop()
cl.R.Stop()
cl.W.Stop()
cl.State.endedW.Wait()
cl.conn.Close()
_ = cl.conn.Close() // omit close error
cl.State.endedR.Wait()
atomic.StoreInt64(&cl.State.Done, 1)
atomic.StoreUint32(&cl.State.Done, 1)
if err == nil {
err = ErrConnectionClosed
}
cl.State.stopCause.Store(err)
})
}
// readFixedHeader reads in the values of the next packet's fixed header.
// StopCause returns the reason the client connection was stopped, if any.
func (cl *Client) StopCause() error {
if cl.State.stopCause.Load() == nil {
return nil
}
return cl.State.stopCause.Load().(error)
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
p, err := cl.r.Read(1)
p, err := cl.R.Read(1)
if err != nil {
return err
}
@@ -275,7 +320,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
i := 1
n := 2
for ; n < 6; n++ {
p, err = cl.r.Read(n)
p, err = cl.R.Read(n)
if err != nil {
return err
}
@@ -299,8 +344,8 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
fh.Remaining = int(rem)
// Having successfully read n bytes, commit the tail forward.
cl.r.CommitTail(n)
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
cl.R.CommitTail(n)
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
return nil
}
@@ -309,7 +354,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
// an error is encountered (or the connection is closed).
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
for {
if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.R.CapDelta() == 0 {
return nil
}
@@ -334,18 +379,18 @@ func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.system.MessagesRecv, 1)
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
pk.FixedHeader = *fh
if pk.FixedHeader.Remaining == 0 {
return
}
p, err := cl.r.Read(pk.FixedHeader.Remaining)
p, err := cl.R.Read(pk.FixedHeader.Remaining)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
@@ -359,7 +404,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.system.PublishRecv, 1)
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
@@ -384,19 +429,19 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
}
cl.r.CommitTail(pk.FixedHeader.Remaining)
cl.R.CommitTail(pk.FixedHeader.Remaining)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if atomic.LoadInt64(&cl.State.Done) == 1 {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return 0, ErrConnectionClosed
}
cl.w.Mu.Lock()
defer cl.w.Mu.Unlock()
cl.W.Mu.Lock()
defer cl.W.Mu.Unlock()
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
@@ -407,7 +452,7 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
case packets.Publish:
err = pk.PublishEncode(buf)
if err == nil {
atomic.AddInt64(&cl.system.PublishSent, 1)
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
}
case packets.Puback:
err = pk.PubackEncode(buf)
@@ -439,12 +484,13 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
}
// Write the packet bytes to the client byte buffer.
n, err = cl.w.Write(buf.Bytes())
n, err = cl.W.Write(buf.Bytes())
if err != nil {
return
}
atomic.AddInt64(&cl.system.BytesSent, int64(n))
atomic.AddInt64(&cl.system.MessagesSent, 1)
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
cl.refreshDeadline(cl.keepalive)
@@ -453,8 +499,8 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
// LWT contains the last will and testament details for a client connection.
type LWT struct {
Topic string // the topic the will message shall be sent to.
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}

View File

@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/mochi-co/mqtt/server/events"
"github.com/mochi-co/mqtt/server/internal/circ"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/listeners/auth"
@@ -16,6 +17,8 @@ import (
"github.com/stretchr/testify/require"
)
var testClientStop = errors.New("test stop")
func genClient() *Client {
c, _ := net.Pipe()
return NewClient(c, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
@@ -130,11 +133,39 @@ func TestNewClient(t *testing.T) {
require.NotNil(t, cl)
require.NotNil(t, cl.Inflight.internal)
require.NotNil(t, cl.Subscriptions)
require.NotNil(t, cl.r)
require.NotNil(t, cl.w)
require.NotNil(t, cl.State.started)
require.NotNil(t, cl.State.endedW)
require.NotNil(t, cl.State.endedR)
require.NotNil(t, cl.R)
require.NotNil(t, cl.W)
require.Nil(t, cl.StopCause())
}
func TestClientInfoUnknown(t *testing.T) {
cl := genClient()
cl.ID = "testid"
cl.Listener = "testlistener"
cl.conn = nil
require.Equal(t, events.Client{
ID: "testid",
Remote: "unknown",
Listener: "testlistener",
}, cl.Info())
}
func TestClientInfoKnown(t *testing.T) {
c1, c2 := net.Pipe()
defer c1.Close()
defer c2.Close()
cl := genClient()
cl.ID = "ID"
cl.Listener = "L"
cl.conn = c1
require.Equal(t, events.Client{
ID: "ID",
Remote: c1.RemoteAddr().String(),
Listener: "L",
}, cl.Info())
}
func BenchmarkNewClient(b *testing.B) {
@@ -175,7 +206,7 @@ func TestClientIdentify(t *testing.T) {
cl.Identify("tcp1", pk, new(auth.Allow))
require.Equal(t, pk.Keepalive, cl.keepalive)
require.Equal(t, pk.CleanSession, cl.cleanSession)
require.Equal(t, pk.CleanSession, cl.CleanSession)
require.Equal(t, pk.ClientIdentifier, cl.ID)
}
@@ -314,15 +345,15 @@ func BenchmarkClientRefreshDeadline(b *testing.B) {
func TestClientStart(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
time.Sleep(time.Millisecond)
require.Equal(t, int64(1), atomic.LoadInt64(&cl.r.State))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.w.State))
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.R.State))
require.Equal(t, uint32(2), atomic.LoadUint32(&cl.W.State))
}
func BenchmarkClientStart(b *testing.B) {
cl := genClient()
defer cl.Stop()
defer cl.Stop(testClientStop)
for n := 0; n < b.N; n++ {
cl.Start()
@@ -332,17 +363,17 @@ func BenchmarkClientStart(b *testing.B) {
func TestClientReadFixedHeader(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
cl.r.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
cl.r.SetPos(0, 2)
cl.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
cl.R.SetPos(0, 2)
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
tail, head := cl.r.GetPos()
tail, head := cl.R.GetPos()
require.Equal(t, int64(2), tail)
require.Equal(t, int64(2), head)
@@ -351,13 +382,13 @@ func TestClientReadFixedHeader(t *testing.T) {
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
o := make(chan error)
go func() {
fh := new(packets.FixedHeader)
cl.r.Set([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}, 0, 2)
cl.r.SetPos(0, 2)
cl.R.Set([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}, 0, 2)
cl.R.SetPos(0, 2)
o <- cl.ReadFixedHeader(fh)
}()
time.Sleep(time.Millisecond)
@@ -367,17 +398,17 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
o := make(chan error)
go func() {
fh := new(packets.FixedHeader)
cl.r.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
cl.r.SetPos(0, 1)
cl.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
cl.R.SetPos(0, 1)
o <- cl.ReadFixedHeader(fh)
}()
time.Sleep(time.Millisecond)
cl.r.Stop()
cl.R.Stop()
err := <-o
require.Error(t, err)
require.Equal(t, io.EOF, err)
@@ -386,14 +417,14 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
o := make(chan error)
go func() {
fh := new(packets.FixedHeader)
err := cl.r.Set([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, 0, 5)
err := cl.R.Set([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, 0, 5)
require.NoError(t, err)
cl.r.SetPos(0, 5)
cl.R.SetPos(0, 5)
o <- cl.ReadFixedHeader(fh)
}()
time.Sleep(time.Millisecond)
@@ -403,7 +434,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
func TestClientReadOK(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
// Two packets in a row...
b := []byte{
@@ -417,9 +448,9 @@ func TestClientReadOK(t *testing.T) {
'y', 'e', 'a', 'h', // Payload
}
err := cl.r.Set(b, 0, len(b))
err := cl.R.Set(b, 0, len(b))
require.NoError(t, err)
cl.r.SetPos(0, int64(len(b)))
cl.R.SetPos(0, int64(len(b)))
o := make(chan error)
var pks []packets.Packet
@@ -431,7 +462,7 @@ func TestClientReadOK(t *testing.T) {
}()
time.Sleep(time.Millisecond)
cl.r.Stop()
cl.R.Stop()
err = <-o
require.Error(t, err)
@@ -456,15 +487,25 @@ func TestClientReadOK(t *testing.T) {
},
})
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.MessagesRecv))
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.MessagesRecv))
}
func TestClientClearBuffers(t *testing.T) {
cl := genClient()
cl.Start()
cl.Stop(testClientStop)
cl.ClearBuffers()
require.Nil(t, cl.W)
require.Nil(t, cl.R)
}
func TestClientReadDone(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
cl.State.Done = 1
err := cl.Read(func(cl *Client, pk packets.Packet) error {
@@ -477,7 +518,7 @@ func TestClientReadDone(t *testing.T) {
func TestClientReadPacketError(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
b := []byte{
0, 18,
@@ -485,9 +526,9 @@ func TestClientReadPacketError(t *testing.T) {
'a', '/', 'b', '/', 'c',
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i',
}
err := cl.r.Set(b, 0, len(b))
err := cl.R.Set(b, 0, len(b))
require.NoError(t, err)
cl.r.SetPos(0, int64(len(b)))
cl.R.SetPos(0, int64(len(b)))
o := make(chan error)
go func() {
@@ -499,10 +540,37 @@ func TestClientReadPacketError(t *testing.T) {
require.Error(t, <-o)
}
func TestClientReadPacketEOF(t *testing.T) {
cl := genClient()
cl.Start()
b := []byte{
0, 18,
0, 5,
'a', '/', 'b', '/', 'c',
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', // missing 1 byte
}
err := cl.R.Set(b, 0, len(b))
require.NoError(t, err)
cl.R.SetPos(0, int64(len(b)))
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
return nil
})
}()
cl.R.Stop()
cl.Stop(testClientStop)
require.Error(t, <-o)
require.True(t, errors.Is(cl.StopCause(), testClientStop))
}
func TestClientReadHandlerErr(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
b := []byte{
byte(packets.Publish << 4), 11, // Fixed header
@@ -511,9 +579,9 @@ func TestClientReadHandlerErr(t *testing.T) {
'y', 'e', 'a', 'h', // Payload
}
err := cl.r.Set(b, 0, len(b))
err := cl.R.Set(b, 0, len(b))
require.NoError(t, err)
cl.r.SetPos(0, int64(len(b)))
cl.R.SetPos(0, int64(len(b)))
err = cl.Read(func(cl *Client, pk packets.Packet) error {
return errors.New("test")
@@ -525,16 +593,16 @@ func TestClientReadHandlerErr(t *testing.T) {
func TestClientReadPacketOK(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
err := cl.r.Set([]byte{
err := cl.R.Set([]byte{
byte(packets.Publish << 4), 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
}, 0, 13)
require.NoError(t, err)
cl.r.SetPos(0, 13)
cl.R.SetPos(0, 13)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
@@ -557,12 +625,12 @@ func TestClientReadPacketOK(t *testing.T) {
func TestClientReadPacket(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
for i, tt := range pkTable {
err := cl.r.Set(tt.bytes, 0, len(tt.bytes))
err := cl.R.Set(tt.bytes, 0, len(tt.bytes))
require.NoError(t, err)
cl.r.SetPos(0, int64(len(tt.bytes)))
cl.R.SetPos(0, int64(len(tt.bytes)))
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
@@ -574,7 +642,7 @@ func TestClientReadPacket(t *testing.T) {
require.Equal(t, tt.packet, pk, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishRecv))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishRecv))
}
}
}
@@ -582,16 +650,16 @@ func TestClientReadPacket(t *testing.T) {
func TestClientReadPacketReadingError(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
err := cl.r.Set([]byte{
err := cl.R.Set([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
}, 0, 13)
require.NoError(t, err)
cl.r.SetPos(2, 13)
cl.R.SetPos(2, 13)
_, err = cl.ReadPacket(&packets.FixedHeader{
Type: 0,
@@ -603,8 +671,8 @@ func TestClientReadPacketReadingError(t *testing.T) {
func TestClientReadPacketReadError(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
cl.r.Stop()
defer cl.Stop(testClientStop)
cl.R.Stop()
_, err := cl.ReadPacket(&packets.FixedHeader{
Remaining: 1,
@@ -616,8 +684,8 @@ func TestClientReadPacketReadError(t *testing.T) {
func TestClientReadPacketReadUnknown(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
cl.r.Stop()
defer cl.Stop(testClientStop)
cl.R.Stop()
_, err := cl.ReadPacket(&packets.FixedHeader{
Remaining: 1,
@@ -646,11 +714,21 @@ func TestClientWritePacket(t *testing.T) {
r.Close()
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
cl.Stop()
require.Equal(t, int64(n), atomic.LoadInt64(&cl.system.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.MessagesSent))
cl.Stop(testClientStop)
// The stop cause is either the test error, EOF, or a
// closed pipe, depending on which goroutine runs first.
err = cl.StopCause()
time.Sleep(time.Millisecond * 5)
require.True(t,
errors.Is(err, testClientStop) ||
errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe))
require.Equal(t, int64(n), atomic.LoadInt64(&cl.systemInfo.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.MessagesSent))
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishSent))
}
}
}
@@ -658,8 +736,8 @@ func TestClientWritePacket(t *testing.T) {
func TestClientWritePacketWriteNoConn(t *testing.T) {
c, _ := net.Pipe()
cl := NewClient(c, circ.NewReader(16, 4), circ.NewWriter(16, 4), new(system.Info))
cl.w.SetPos(0, 16)
cl.Stop()
cl.W.SetPos(0, 16)
cl.Stop(testClientStop)
_, err := cl.WritePacket(pkTable[1].packet)
require.Error(t, err)
@@ -669,8 +747,8 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
func TestClientWritePacketWriteError(t *testing.T) {
c, _ := net.Pipe()
cl := NewClient(c, circ.NewReader(16, 4), circ.NewWriter(16, 4), new(system.Info))
cl.w.SetPos(0, 16)
cl.w.Stop()
cl.W.SetPos(0, 16)
cl.W.Stop()
_, err := cl.WritePacket(pkTable[1].packet)
require.Error(t, err)
@@ -729,7 +807,7 @@ func TestInflightGetAll(t *testing.T) {
m := cl.Inflight.GetAll()
o := map[uint16]InflightMessage{
2: InflightMessage{},
2: {},
}
require.Equal(t, o, m)
}

View File

@@ -28,6 +28,10 @@ func decodeString(buf []byte, offset int) (string, int, error) {
return "", 0, err
}
if !validUTF8(b) {
return "", 0, ErrOffsetStrInvalidUTF8
}
return bytesToString(b), n, nil
}
@@ -39,12 +43,10 @@ func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
}
if next+int(length) > len(buf) {
return make([]byte, 0, 0), 0, ErrOffsetStrOutOfRange
return make([]byte, 0, 0), 0, ErrOffsetBytesOutOfRange
}
if !validUTF8(buf[next : next+int(length)]) {
return make([]byte, 0, 0), 0, ErrOffsetStrInvalidUTF8
}
// Note: there is no validUTF8() test for []byte payloads
return buf[next : next+int(length)], next + int(length), nil
}

View File

@@ -1,6 +1,8 @@
package packets
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/require"
@@ -19,15 +21,16 @@ func BenchmarkBytesToString(b *testing.B) {
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
rawBytes []byte
result []string
result string
offset int
shouldFail bool
shouldFail error
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: []string{"a/b/c/d", "a"},
result: "a/b/c/d",
},
{
offset: 14,
@@ -41,36 +44,32 @@ func TestDecodeString(t *testing.T) {
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "zen"},
},
result: []string{"hey"},
result: "hey",
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: []string{"1/2/3/4/a/b/c/d/e/^/@/!", "a"},
result: "1/2/3/4/a/b/c/d/e/^/@/!",
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: []string{"x/y/z", "!@#$%^&"},
result: "x/y/z",
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
result: []string{"a/b/c/d", "z"},
shouldFail: true,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
result: []string{"a/b/c/d", "x"},
shouldFail: true,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
result: []string{"a/b/c/d", "y"},
shouldFail: true,
shouldFail: ErrOffsetUintOutOfRange,
},
{
offset: 17,
@@ -86,20 +85,26 @@ func TestDecodeString(t *testing.T) {
0, 6, // Will Topic - MSB+LSB
'l',
},
result: []string{"lwt"},
shouldFail: true,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrOffsetStrInvalidUTF8,
},
}
for i, wanted := range expect {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding string [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding string [i:%d]", i)
require.Equal(t, wanted.result[0], result, "Incorrect decoded value [i:%d]", i)
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
@@ -116,7 +121,7 @@ func TestDecodeBytes(t *testing.T) {
result []uint8
next int
offset int
shouldFail bool
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session)
@@ -132,33 +137,27 @@ func TestDecodeBytes(t *testing.T) {
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 0,
shouldFail: true,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 8,
shouldFail: true,
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 0,
shouldFail: true,
shouldFail: ErrOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding bytes [i:%d]", i)
continue
}
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err, "Error decoding bytes [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
@@ -174,7 +173,7 @@ func TestDecodeByte(t *testing.T) {
rawBytes []byte
result uint8
offset int
shouldFail bool
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
@@ -198,22 +197,23 @@ func TestDecodeByte(t *testing.T) {
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
result: uint8(0x00),
offset: 8,
shouldFail: true,
shouldFail: ErrOffsetByteOutOfRange,
},
}
for i, wanted := range expect {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding byte [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding byte [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, i+1, offset, "Incorrect offset value [i:%d]", i)
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
})
}
}
@@ -229,7 +229,7 @@ func TestDecodeUint16(t *testing.T) {
rawBytes []byte
result uint16
offset int
shouldFail bool
shouldFail error
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
@@ -243,22 +243,24 @@ func TestDecodeUint16(t *testing.T) {
},
{
rawBytes: []byte{0, 7, 255, 47},
result: uint16(0x761),
offset: 8,
shouldFail: true,
shouldFail: ErrOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding uint16 [i:%d]", i)
continue
}
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err, "Error decoding uint16 [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, i+2, offset, "Incorrect offset value [i:%d]", i)
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
})
}
}
@@ -274,7 +276,7 @@ func TestDecodeByteBool(t *testing.T) {
rawBytes []byte
result bool
offset int
shouldFail bool
shouldFail error
}{
{
rawBytes: []byte{0x00, 0x00},
@@ -287,20 +289,22 @@ func TestDecodeByteBool(t *testing.T) {
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: true,
shouldFail: ErrOffsetBoolOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding byte bool [i:%d]", i)
continue
}
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err, "Error decoding byte bool [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, 1, offset, "Incorrect offset value [i:%d]", i)
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, 1, offset)
})
}
}

View File

@@ -6,20 +6,20 @@ import (
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Dup bool // indicates if the packet was already sent at an earlier time.
Qos byte // indicates the quality of service expected.
Retain bool // whether the message should be retained.
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, fh.Remaining)
encodeLength(buf, int64(fh.Remaining))
}
// decode extracts the specification bits from the header byte.
// Decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(headerByte byte) error {
fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes.
@@ -44,7 +44,7 @@ func (fh *FixedHeader) Decode(headerByte byte) error {
}
// encodeLength writes length bits for the header.
func encodeLength(buf *bytes.Buffer, length int) {
func encodeLength(buf *bytes.Buffer, length int64) {
for {
digit := byte(length % 128)
length /= 128

View File

@@ -18,130 +18,130 @@ type fixedHeaderTable struct {
var fixedHeaderExpected = []fixedHeaderTable{
{
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Connack, false, 0, false, 0},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Publish, false, 0, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Publish, false, 1, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 1, true, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Publish, false, 2, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 2, true, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Publish, true, 0, false, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Publish, true, 0, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 1, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 2, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Puback, false, 0, false, 0},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Pubrec, false, 0, false, 0},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Pubrel, false, 1, false, 0},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Pubcomp, false, 0, false, 0},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Subscribe, false, 1, false, 0},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Suback, false, 0, false, 0},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Unsubscribe, false, 1, false, 0},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Unsuback, false, 0, false, 0},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Pingreq, false, 0, false, 0},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Pingresp, false, 0, false, 0},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Disconnect, false, 0, false, 0},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Publish, false, 0, false, 10},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Publish, false, 0, false, 512},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Publish, false, 0, false, 978},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Publish, false, 0, false, 20102},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Publish, false, 0, false, 333333333},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Connect, true, 0, false, 0},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Connect, false, 1, false, 0},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Connect, false, 0, true, 0},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
flagError: true,
},
}
@@ -192,7 +192,7 @@ func BenchmarkFixedHeaderDecode(b *testing.B) {
func TestEncodeLength(t *testing.T) {
tt := []struct {
have int
have int64
want []byte
}{
{

View File

@@ -3,6 +3,8 @@ package packets
import (
"bytes"
"errors"
"fmt"
"strconv"
)
// All of the valid packet types and their packet identifier.
@@ -60,7 +62,6 @@ var (
// PACKETS
ErrProtocolViolation = errors.New("protocol violation")
ErrOffsetStrOutOfRange = errors.New("offset string out of range")
ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range")
ErrOffsetByteOutOfRange = errors.New("offset byte out of range")
ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")
@@ -76,40 +77,31 @@ var (
// packet structs, this is a single concrete packet type to cover all packet
// types, which allows us to take advantage of various compiler optimizations.
type Packet struct {
FixedHeader FixedHeader
PacketID uint16
// Connect
FixedHeader FixedHeader
AllowClients []string // For use with OnMessage event hook.
Topics []string
ReturnCodes []byte
ProtocolName []byte
Qoss []byte
Payload []byte
Username []byte
Password []byte
WillMessage []byte
ClientIdentifier string
TopicName string
WillTopic string
PacketID uint16
Keepalive uint16
ReturnCode byte
ProtocolVersion byte
WillQos byte
ReservedBit byte
CleanSession bool
WillFlag bool
WillQos byte
WillRetain bool
UsernameFlag bool
PasswordFlag bool
ReservedBit byte
Keepalive uint16
ClientIdentifier string
WillTopic string
WillMessage []byte
Username []byte
Password []byte
// Connack
SessionPresent bool
ReturnCode byte
// Publish
TopicName string
Payload []byte
// Subscribe, Unsubscribe
Topics []string
Qoss []byte
ReturnCodes []byte // Suback
SessionPresent bool
}
// ConnectEncode encodes a connect packet.
@@ -169,17 +161,17 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
// Unpack protocol name and version.
pk.ProtocolName, offset, err = decodeBytes(buf, 0)
if err != nil {
return ErrMalformedProtocolName
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolName)
}
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedProtocolVersion
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolVersion)
}
// Unpack flags byte.
flags, offset, err := decodeByte(buf, offset)
if err != nil {
return ErrMalformedFlags
return fmt.Errorf("%s: %w", err, ErrMalformedFlags)
}
pk.ReservedBit = 1 & flags
pk.CleanSession = 1&(flags>>1) > 0
@@ -192,25 +184,25 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
// Get keepalive interval.
pk.Keepalive, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedKeepalive
return fmt.Errorf("%s: %w", err, ErrMalformedKeepalive)
}
// Get client ID.
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedClientID
return fmt.Errorf("%s: %w", err, ErrMalformedClientID)
}
// Get Last Will and Testament topic and message if applicable.
if pk.WillFlag {
pk.WillTopic, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedWillTopic
return fmt.Errorf("%s: %w", err, ErrMalformedWillTopic)
}
pk.WillMessage, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedWillMessage
return fmt.Errorf("%s: %w", err, ErrMalformedWillMessage)
}
}
@@ -218,14 +210,14 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if pk.UsernameFlag {
pk.Username, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedUsername
return fmt.Errorf("%s: %w", err, ErrMalformedUsername)
}
}
if pk.PasswordFlag {
pk.Password, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedPassword
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
}
}
@@ -292,12 +284,12 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
if err != nil {
return ErrMalformedSessionPresent
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
}
pk.ReturnCode, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedReturnCode
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
}
return nil
@@ -334,7 +326,7 @@ func (pk *Packet) PubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
@@ -352,7 +344,7 @@ func (pk *Packet) PubcompDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
@@ -390,14 +382,14 @@ func (pk *Packet) PublishDecode(buf []byte) error {
pk.TopicName, offset, err = decodeString(buf, 0)
if err != nil {
return ErrMalformedTopic
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
// If QOS decode Packet ID.
if pk.FixedHeader.Qos > 0 {
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
}
@@ -451,7 +443,7 @@ func (pk *Packet) PubrecDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
@@ -470,7 +462,7 @@ func (pk *Packet) PubrelDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
@@ -495,7 +487,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
// Get Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Get Granted QOS flags.
@@ -542,7 +534,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Keep decoding until there's no space left.
@@ -552,7 +544,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
var topic string
topic, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedTopic
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
pk.Topics = append(pk.Topics, topic)
@@ -560,7 +552,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
var qos byte
qos, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedQoS
return fmt.Errorf("%s: %w", err, ErrMalformedQoS)
}
// Ensure QoS byte is within range.
@@ -599,7 +591,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
@@ -641,7 +633,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Keep decoding until there's no space left.
@@ -649,7 +641,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
var t string
t, offset, err = decodeString(buf, offset) // Decode Topic Name.
if err != nil {
return ErrMalformedTopic
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
if len(t) > 0 {
@@ -671,3 +663,8 @@ func (pk *Packet) UnsubscribeValidate() (byte, error) {
return Accepted, nil
}
// FormatID returns the PacketID field as a decimal integer.
func (pk *Packet) FormatID() string {
return strconv.FormatUint(uint64(pk.PacketID), 10)
}

View File

@@ -6,7 +6,7 @@ type packetTestData struct {
actualBytes []byte // the actual byte array that is created in the event of a byte mutation (eg. MQTT-2.3.1-1 qos/packet id)
packet *Packet // the packet that is expected
desc string // a description of the test
failFirst interface{} // expected fail result to be run immediately after the method is called
failFirst error // expected fail result to be run immediately after the method is called
expect interface{} // generic expected fail result to be checked
isolate bool // isolate can be used to isolate a test
primary bool // primary is a test that should be run using readPackets

View File

@@ -2,6 +2,8 @@ package packets
import (
"bytes"
"errors"
"fmt"
"testing"
"github.com/jinzhu/copier"
@@ -74,13 +76,13 @@ func TestConnectDecode(t *testing.T) {
}
require.Equal(t, uint8(1), Connect, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficent bytes in packet [i:%d] %s", i, wanted.desc)
require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficient bytes in packet [i:%d] %s", i, wanted.desc)
pk := &Packet{FixedHeader: FixedHeader{Type: Connect}}
err := pk.ConnectDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -180,7 +182,7 @@ func TestConnackDecode(t *testing.T) {
err := pk.ConnackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -339,7 +341,7 @@ func TestPubackDecode(t *testing.T) {
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -407,7 +409,7 @@ func TestPubcompDecode(t *testing.T) {
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -494,7 +496,7 @@ func TestPublishDecode(t *testing.T) {
err := pk.PublishDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected fh error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -630,7 +632,7 @@ func TestPubrecDecode(t *testing.T) {
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -703,7 +705,7 @@ func TestPubrelDecode(t *testing.T) {
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -775,7 +777,7 @@ func TestSubackDecode(t *testing.T) {
err := pk.SubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -854,7 +856,7 @@ func TestSubscribeDecode(t *testing.T) {
err := pk.SubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -957,7 +959,7 @@ func TestUnsubackDecode(t *testing.T) {
err := pk.UnsubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -1034,7 +1036,7 @@ func TestUnsubscribeDecode(t *testing.T) {
err := pk.UnsubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -1080,3 +1082,10 @@ func BenchmarkUnsubscribeValidate(b *testing.B) {
pk.UnsubscribeValidate()
}
}
func TestFormatPacketID(t *testing.T) {
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
packet := &Packet{PacketID: id}
require.Equal(t, fmt.Sprint(id), packet.FormatID())
}
}

View File

@@ -27,27 +27,28 @@ func New() *Index {
}
// RetainMessage saves a message payload to the end of a topic branch. Returns
// 1 if a retained message was added, 0 if there was no change, and -1 if the
// retained message was removed.
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *Index) RetainMessage(msg packets.Packet) int64 {
var q int64
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(msg.TopicName)
// If there is a payload, we can store it.
if len(msg.Payload) > 0 {
if n.Message.FixedHeader.Retain == false {
q = 1
}
n.Message = msg
} else {
if n.Message.FixedHeader.Retain == true {
q = -1
}
x.unpoperate(msg.TopicName, "", true)
return 1
}
return q
// Otherwise, we are unsetting it.
// If there was a previous retained message, return -1 instead of 0.
var r int64 = 0
if len(n.Message.Payload) > 0 && n.Message.FixedHeader.Retain == true {
r = -1
}
x.unpoperate(msg.TopicName, "", true)
return r
}
// Subscribe creates a subscription filter for a client. Returns true if the
@@ -175,12 +176,12 @@ func (x *Index) Messages(filter string) []packets.Packet {
// Leaf is a child node on the tree.
type Leaf struct {
Message packets.Packet // a message which has been retained for a specific topic.
Key string // the key that was used to create the leaf.
Filter string // the path of the topic filter being matched.
Parent *Leaf // a pointer to the parent node for the leaf.
Leaves map[string]*Leaf // a map of child nodes, keyed on particle id.
Clients map[string]byte // a map of client ids subscribed to the topic.
Filter string // the path of the topic filter being matched.
Message packets.Packet // a message which has been retained for a specific topic.
}
// scanSubscribers recursively steps through a branch of leaves finding clients who

View File

@@ -113,8 +113,10 @@ func TestRetainMessage(t *testing.T) {
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
q = index.RetainMessage(pk2) // already exsiting
require.Equal(t, int64(0), q)
// The same message already exists, but we're not doing a deep-copy check, so it's considered
// to be a new message.
q = index.RetainMessage(pk2)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
@@ -126,6 +128,14 @@ func TestRetainMessage(t *testing.T) {
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
// Second Delete retained
q = index.RetainMessage(pk3)
require.Equal(t, int64(0), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
}
func BenchmarkRetainMessage(b *testing.B) {

View File

@@ -0,0 +1,14 @@
package utils
// InSliceString returns true if a string exists in a slice of strings.
// This temporary and should be replaced with a function from the new
// go slices package in 1.19 when available.
// https://github.com/golang/go/issues/45955
func InSliceString(sl []string, st string) bool {
for _, v := range sl {
if st == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,18 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInSliceString(t *testing.T) {
sl := []string{"a", "b", "c"}
require.Equal(t, true, InSliceString(sl, "b"))
sl = []string{"a", "a", "a"}
require.Equal(t, true, InSliceString(sl, "a"))
sl = []string{"a", "b", "c"}
require.Equal(t, false, InSliceString(sl, "d"))
}

View File

@@ -3,7 +3,7 @@ package auth
// Allow is an auth controller which allows access to all connections and topics.
type Allow struct{}
// Auth returns true if a username and password are acceptable. Allow always
// Authenticate returns true if a username and password are acceptable. Allow always
// returns true.
func (a *Allow) Authenticate(user, password []byte) bool {
return true
@@ -18,7 +18,7 @@ func (a *Allow) ACL(user []byte, topic string, write bool) bool {
// Disallow is an auth controller which disallows access to all connections and topics.
type Disallow struct{}
// Auth returns true if a username and password are acceptable. Disallow always
// Authenticate returns true if a username and password are acceptable. Disallow always
// returns false.
func (d *Disallow) Authenticate(user, password []byte) bool {
return false

View File

@@ -18,11 +18,11 @@ import (
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
system *system.Info // pointers to the server data.
address string // the network address to bind to.
listen *http.Server // the http server.
end int64 // ensure the close methods are only called once.}
end uint32 // ensure the close methods are only called once.}
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
@@ -98,9 +98,7 @@ func (l *HTTPStats) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)

View File

@@ -38,10 +38,10 @@ type Listener interface {
// Listeners contains the network listeners for the broker.
type Listeners struct {
sync.RWMutex
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
internal map[string]Listener // a map of active listeners.
system *system.Info // pointers to system info.
sync.RWMutex
}
// New returns a new instance of Listeners.

View File

@@ -22,11 +22,11 @@ func MockEstablisher(id string, c net.Conn, ac auth.Controller) error {
type MockListener struct {
sync.RWMutex
id string // the id of the listener.
Config *Config // configuration for the listener.
address string // the network address the listener binds to.
Listening bool // indiciate the listener is listening.
Serving bool // indicate the listener is serving.
Config *Config // configuration for the listener.
done chan bool // indicate the listener is done.
Serving bool // indicate the listener is serving.
Listening bool // indiciate the listener is listening.
ErrListen bool // throw an error on listen.
}
@@ -44,15 +44,12 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
l.Lock()
l.Serving = true
l.Unlock()
for {
select {
case <-l.done:
return
}
for range l.done {
return
}
}
// SetConfig sets the configuration values of the mock listener.
// Listen begins listening for incoming traffic.
func (l *MockListener) Listen(s *system.Info) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
@@ -95,7 +92,7 @@ func (l *MockListener) IsServing() bool {
return l.Serving
}
// IsServing indicates whether the mock listener is listening.
// IsListening indicates whether the mock listener is listening.
func (l *MockListener) IsListening() bool {
l.Lock()
defer l.Unlock()

View File

@@ -14,11 +14,11 @@ import (
type TCP struct {
sync.RWMutex
id string // the internal id of the listener.
config *Config // configuration values for the listener.
protocol string // the TCP protocol to use.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
end int64 // ensure the close methods are only called once.
config *Config // configuration values for the listener.
end uint32 // ensure the close methods are only called once.
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
@@ -63,10 +63,12 @@ func (l *TCP) Listen(s *system.Info) error {
var err error
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
@@ -84,7 +86,7 @@ func (l *TCP) Listen(s *system.Info) error {
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFunc) {
for {
if atomic.LoadInt64(&l.end) == 1 {
if atomic.LoadUint32(&l.end) == 1 {
return
}
@@ -93,8 +95,10 @@ func (l *TCP) Serve(establish EstablishFunc) {
return
}
if atomic.LoadInt64(&l.end) == 0 {
go establish(l.id, conn, l.config.Auth)
if atomic.LoadUint32(&l.end) == 0 {
go func() {
_ = establish(l.id, conn, l.config.Auth)
}()
}
}
}
@@ -104,8 +108,7 @@ func (l *TCP) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}

View File

@@ -17,12 +17,14 @@ import (
)
var (
ErrInvalidMessage = errors.New("Message type not binary")
// ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool { return true },
}
)
@@ -30,11 +32,11 @@ var (
type Websocket struct {
sync.RWMutex
id string // the internal id of the listener.
config *Config // configuration values for the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
listen *http.Server // an http server for serving websocket connections.
end int64 // ensure the close methods are only called once.
establish EstablishFunc // the server's establish conection handler.
establish EstablishFunc // the server's establish connection handler.
end uint32 // ensure the close methods are only called once.
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
@@ -161,8 +163,7 @@ func (l *Websocket) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)

View File

@@ -1,7 +1,7 @@
package bolt
import (
"errors"
"fmt"
"time"
sgob "github.com/asdine/storm/codec/gob"
@@ -12,12 +12,17 @@ import (
)
const (
defaultPath = "mochi.db"
// defaultPath is the default file to use to store the data.
defaultPath = "mochi.db"
// defaultTimeout is the default timeout of the file lock.
defaultTimeout = 250 * time.Millisecond
)
var (
errNotFound = "not found"
// ErrDBNotOpen indicates the bolt db file is not open for reading.
ErrDBNotOpen = fmt.Errorf("boltdb not opened")
)
// Store is a backend for writing and reading to bolt persistent storage.
@@ -64,7 +69,7 @@ func (s *Store) Close() {
// WriteServerInfo writes the server info to the boltdb instance.
func (s *Store) WriteServerInfo(v persistence.ServerInfo) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -77,7 +82,7 @@ func (s *Store) WriteServerInfo(v persistence.ServerInfo) error {
// WriteSubscription writes a single subscription to the boltdb instance.
func (s *Store) WriteSubscription(v persistence.Subscription) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -90,7 +95,7 @@ func (s *Store) WriteSubscription(v persistence.Subscription) error {
// WriteInflight writes a single inflight message to the boltdb instance.
func (s *Store) WriteInflight(v persistence.Message) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -103,7 +108,7 @@ func (s *Store) WriteInflight(v persistence.Message) error {
// WriteRetained writes a single retained message to the boltdb instance.
func (s *Store) WriteRetained(v persistence.Message) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -116,7 +121,7 @@ func (s *Store) WriteRetained(v persistence.Message) error {
// WriteClient writes a single client to the boltdb instance.
func (s *Store) WriteClient(v persistence.Client) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -129,7 +134,7 @@ func (s *Store) WriteClient(v persistence.Client) error {
// DeleteSubscription deletes a subscription from the boltdb instance.
func (s *Store) DeleteSubscription(id string) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Subscription{
@@ -145,7 +150,7 @@ func (s *Store) DeleteSubscription(id string) error {
// DeleteClient deletes a client from the boltdb instance.
func (s *Store) DeleteClient(id string) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Client{
@@ -161,7 +166,7 @@ func (s *Store) DeleteClient(id string) error {
// DeleteInflight deletes an inflight message from the boltdb instance.
func (s *Store) DeleteInflight(id string) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Message{
@@ -177,7 +182,7 @@ func (s *Store) DeleteInflight(id string) error {
// DeleteRetained deletes a retained message from the boltdb instance.
func (s *Store) DeleteRetained(id string) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Message{
@@ -193,11 +198,11 @@ func (s *Store) DeleteRetained(id string) error {
// ReadSubscriptions loads all the subscriptions from the boltdb instance.
func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KSubscription, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}
@@ -207,11 +212,11 @@ func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) {
// ReadClients loads all the clients from the boltdb instance.
func (s *Store) ReadClients() (v []persistence.Client, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KClient, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}
@@ -221,11 +226,11 @@ func (s *Store) ReadClients() (v []persistence.Client, err error) {
// ReadInflight loads all the inflight messages from the boltdb instance.
func (s *Store) ReadInflight() (v []persistence.Message, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KInflight, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}
@@ -235,11 +240,11 @@ func (s *Store) ReadInflight() (v []persistence.Message, err error) {
// ReadRetained loads all the retained messages from the boltdb instance.
func (s *Store) ReadRetained() (v []persistence.Message, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KRetained, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}
@@ -249,11 +254,11 @@ func (s *Store) ReadRetained() (v []persistence.Message, err error) {
//ReadServerInfo loads the server info from the boltdb instance.
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.One("ID", persistence.KServerInfo, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}

View File

@@ -7,11 +7,21 @@ import (
)
const (
// KSubscription is the key for subscription data.
KSubscription = "sub"
KServerInfo = "srv"
KRetained = "ret"
KInflight = "ifm"
KClient = "cl"
// KServerInfo is the key for server info data.
KServerInfo = "srv"
// KRetained is the key for retained messages data.
KRetained = "ret"
// KInflight is the key for inflight messages data.
KInflight = "ifm"
// KClient is the key for client data.
KClient = "cl"
)
// Store is an interface which details a persistent storage connector.
@@ -53,50 +63,50 @@ type Subscription struct {
// Message contains the details of a retained or inflight message.
type Message struct {
ID string // the storage key.
T string // the type of the stored data.
Client string // the id of the client who sent the message (if inflight).
FixedHeader FixedHeader // the header properties of the message.
PacketID uint16 // the unique id of the packet (if inflight).
TopicName string // the topic the message was sent to (if retained).
Payload []byte // the message payload (if retained).
FixedHeader FixedHeader // the header properties of the message.
T string // the type of the stored data.
ID string // the storage key.
Client string // the id of the client who sent the message (if inflight).
TopicName string // the topic the message was sent to (if retained).
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
Resends int // the number of times the message was attempted to be sent (if inflight).
PacketID uint16 // the unique id of the packet (if inflight).
}
// FixedHeader contains the fixed header properties of a message.
type FixedHeader struct {
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Dup bool // indicates if the packet was already sent at an earlier time.
Qos byte // indicates the quality of service expected.
Retain bool // whether the message should be retained.
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Client contains client data that can be persistently stored.
type Client struct {
LWT LWT // the last-will-and-testament message for the client.
Username []byte // the username the client authenticated with.
ID string // the storage key.
ClientID string // the id of the client.
T string // the type of the stored data.
Listener string // the last known listener id for the client
Username []byte // the username the client authenticated with.
LWT LWT // the last-will-and-testament message for the client.
}
// LWT contains details about a clients LWT payload.
type LWT struct {
Topic string // the topic the will message shall be sent to.
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}
// MockStore is a mock storage backend for testing.
type MockStore struct {
Fail map[string]bool // issue errors for different methods.
FailOpen bool // error on open.
Closed bool // indicate mock store is closed.
Opened bool // indicate mock store is open.
Fail map[string]bool // issue errors for different methods.
}
// Open opens the storage instance.
@@ -197,7 +207,7 @@ func (s *MockStore) ReadSubscriptions() (v []Subscription, err error) {
}
return []Subscription{
Subscription{
{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
@@ -214,7 +224,7 @@ func (s *MockStore) ReadClients() (v []Client, err error) {
}
return []Client{
Client{
{
ID: "cl_client1",
ClientID: "client1",
T: KClient,
@@ -230,7 +240,7 @@ func (s *MockStore) ReadInflight() (v []Message, err error) {
}
return []Message{
Message{
{
ID: "client1_if_100",
T: KInflight,
Client: "client1",
@@ -250,7 +260,7 @@ func (s *MockStore) ReadRetained() (v []Message, err error) {
}
return []Message{
Message{
{
ID: "client1_ret_200",
T: KRetained,
FixedHeader: FixedHeader{

View File

@@ -1,9 +1,10 @@
// packet server provides a MQTT 3.1.1 compliant MQTT server.
// package server provides a MQTT 3.1.1 compliant MQTT server.
package server
import (
"errors"
"fmt"
"io"
"net"
"strconv"
"sync/atomic"
@@ -14,6 +15,7 @@ import (
"github.com/mochi-co/mqtt/server/internal/clients"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
"github.com/mochi-co/mqtt/server/internal/utils"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/persistence"
@@ -21,14 +23,41 @@ import (
)
const (
Version = "1.0.2" // the server version.
// Version indicates the current server version.
Version = "1.1.1"
)
var (
ErrListenerIDExists = errors.New("Listener id already exists")
ErrReadConnectInvalid = errors.New("Connect packet was not valid")
ErrConnectNotAuthorized = errors.New("Connect packet was not authorized")
ErrInvalidTopic = errors.New("Cannot publish to $ and $SYS topics")
// ErrListenerIDExists indicates that a listener with the same id already exists.
ErrListenerIDExists = errors.New("listener id already exists")
// ErrReadConnectInvalid indicates that the connection packet was invalid.
ErrReadConnectInvalid = errors.New("connect packet was not valid")
// ErrConnectNotAuthorized indicates that the connection packet had incorrect auth values.
ErrConnectNotAuthorized = errors.New("connect packet was not authorized")
// ErrInvalidTopic indicates that the specified topic was not valid.
ErrInvalidTopic = errors.New("cannot publish to $ and $SYS topics")
// ErrRejectPacket indicates that a packet should be dropped instead of processed.
ErrRejectPacket = errors.New("packet rejected")
// ErrClientDisconnect indicates that a client disconnected from the server.
ErrClientDisconnect = errors.New("client disconnected")
// ErrClientReconnect indicates that a client attempted to reconnect while still connected.
ErrClientReconnect = errors.New("client sent connect while connected")
// ErrServerShutdown is propagated when the server shuts down.
ErrServerShutdown = errors.New("server is shutting down")
// ErrSessionReestablished indicates that an existing client was replaced by a newly connected
// client. The existing client is disconnected.
ErrSessionReestablished = errors.New("client session re-established")
// ErrConnectionFailed indicates that a client connection attempt failed for other reasons.
ErrConnectionFailed = errors.New("connection attempt failed")
// SysTopicInterval is the number of milliseconds between $SYS topic publishes.
SysTopicInterval time.Duration = 30000
@@ -44,16 +73,26 @@ var (
// Server is an MQTT broker server. It should be created with server.New()
// in order to ensure all the internal fields are correctly populated.
type Server struct {
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
Store persistence.Store // a persistent storage backend if desired.
Options *Options // configurable server options.
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
Clients *clients.Clients // clients which are known to the broker.
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
System *system.Info // values about the server commonly found in $SYS topics.
Store persistence.Store // a persistent storage backend if desired.
done chan bool // indicate that the server is ending.
bytepool circ.BytesPool // a byte pool for incoming and outgoing packets.
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
done chan bool // indicate that the server is ending.
}
// Options contains configurable options for the server.
type Options struct {
// BufferSize overrides the default buffer size (circ.DefaultBufferSize) for the client buffers.
BufferSize int
// BufferBlockSize overrides the default buffer block size (DefaultBlockSize) for the client buffers.
BufferBlockSize int
}
// inlineMessages contains channels for handling inline (direct) publishing.
@@ -62,11 +101,22 @@ type inlineMessages struct {
pub chan packets.Packet // a channel of packets to publish to clients
}
// New returns a new instance of an MQTT broker.
// New returns a new instance of MQTT server with no options.
// This method has been deprecated and will be removed in a future release.
// Please use NewServer instead.
func New() *Server {
return NewServer(nil)
}
// NewServer returns a new instance of an MQTT broker with optional values where applicable.
func NewServer(opts *Options) *Server {
if opts == nil {
opts = new(Options)
}
s := &Server{
done: make(chan bool),
bytepool: circ.NewBytesPool(circ.DefaultBufferSize),
bytepool: circ.NewBytesPool(opts.BufferSize),
Clients: clients.New(),
Topics: topics.New(),
System: &system.Info{
@@ -78,7 +128,8 @@ func New() *Server {
done: make(chan bool),
pub: make(chan packets.Packet, 1024),
},
Events: events.Events{},
Events: events.Events{},
Options: opts,
}
// Expose server stats using the system listener so it can be used in the
@@ -165,118 +216,200 @@ func (s *Server) inlineClient() {
}
}
// readConnectionPacket reads the first incoming header for a connection, and if
// acceptable, returns the valid connection packet.
func (s *Server) readConnectionPacket(cl *clients.Client) (pk packets.Packet, err error) {
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
return
}
pk, err = cl.ReadPacket(fh)
if err != nil {
return
}
if pk.FixedHeader.Type != packets.Connect {
return pk, ErrReadConnectInvalid
}
return
}
// onError is a pass-through method which triggers the OnError
// event hook (if applicable), and returns the provided error.
func (s *Server) onError(cl events.Client, err error) error {
if err == nil {
return err
}
// Note: if the error originates from a real cause, it will
// have been captured as the StopCause. The two cases ignored
// below are ordinary consequences of closing the connection.
// If one of these ordinary conditions stops the connection,
// then the client closed or broke the connection.
if s.Events.OnError != nil &&
!errors.Is(err, io.EOF) {
s.Events.OnError(cl, err)
}
return err
}
// onStorage is a pass-through method which delegates errors from
// the persistent storage adapter to the onError event hook.
func (s *Server) onStorage(cl events.Clientlike, err error) {
if err == nil {
return
}
_ = s.onError(cl.Info(), fmt.Errorf("storage: %w", err))
}
// EstablishConnection establishes a new client when a listener
// accepts a new connection.
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
xbr := s.bytepool.Get() // Get byte buffer from pools for receiving packet data.
xbw := s.bytepool.Get() // and for sending.
defer s.bytepool.Put(xbr)
defer s.bytepool.Put(xbw)
cl := clients.NewClient(c,
circ.NewReaderFromSlice(0, xbr),
circ.NewWriterFromSlice(0, xbw),
circ.NewReaderFromSlice(s.Options.BufferBlockSize, xbr),
circ.NewWriterFromSlice(s.Options.BufferBlockSize, xbw),
s.System,
)
cl.Start()
defer cl.ClearBuffers()
defer cl.Stop(nil)
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
pk, err := s.readConnectionPacket(cl)
if err != nil {
return err
return s.onError(cl.Info(), fmt.Errorf("read connection: %w", err))
}
pk, err := cl.ReadPacket(fh)
ackCode, err := pk.ConnectValidate()
if err != nil {
return err
if err := s.ackConnection(cl, ackCode, false); err != nil {
return s.onError(cl.Info(), fmt.Errorf("invalid connection send ack: %w", err))
}
return s.onError(cl.Info(), fmt.Errorf("validate connection packet: %w", err))
}
if pk.FixedHeader.Type != packets.Connect {
return ErrReadConnectInvalid
}
cl.Identify(lid, pk, ac) // Set client identity values from the connection packet.
cl.Identify(lid, pk, ac)
retcode, _ := pk.ConnectValidate()
if !ac.Authenticate(pk.Username, pk.Password) {
retcode = packets.CodeConnectBadAuthValues
if err := s.ackConnection(cl, packets.CodeConnectBadAuthValues, false); err != nil {
return s.onError(cl.Info(), fmt.Errorf("invalid connection send ack: %w", err))
}
return s.onError(cl.Info(), ErrConnectionFailed)
}
atomic.AddInt64(&s.System.ConnectionsTotal, 1)
atomic.AddInt64(&s.System.ClientsConnected, 1)
defer atomic.AddInt64(&s.System.ClientsConnected, -1)
defer atomic.AddInt64(&s.System.ClientsDisconnected, 1)
var sessionPresent bool
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
existing.Lock()
if atomic.LoadInt64(&existing.State.Done) == 1 {
atomic.AddInt64(&s.System.ClientsDisconnected, -1)
}
existing.Stop()
if pk.CleanSession {
for k := range existing.Subscriptions {
delete(existing.Subscriptions, k)
q := s.Topics.Unsubscribe(k, existing.ID)
if q {
atomic.AddInt64(&s.System.Subscriptions, -1)
}
}
} else {
cl.Inflight = existing.Inflight // Inherit from existing session.
cl.Subscriptions = existing.Subscriptions
sessionPresent = true
}
existing.Unlock()
} else {
atomic.AddInt64(&s.System.ClientsTotal, 1)
if atomic.LoadInt64(&s.System.ClientsConnected) > atomic.LoadInt64(&s.System.ClientsMax) {
atomic.AddInt64(&s.System.ClientsMax, 1)
}
sessionPresent := s.inheritClientSession(pk, cl)
s.Clients.Add(cl)
err = s.ackConnection(cl, ackCode, sessionPresent)
if err != nil {
return s.onError(cl.Info(), fmt.Errorf("ack connection packet: %w", err))
}
s.Clients.Add(cl) // Overwrite any existing client with the same name.
err = s.writeClient(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Connack,
},
SessionPresent: sessionPresent,
ReturnCode: retcode,
})
if err != nil || retcode != packets.Accepted {
return err
err = s.ResendClientInflight(cl, true)
if err != nil {
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
}
s.ResendClientInflight(cl, true)
if s.Store != nil {
s.Store.WriteClient(persistence.Client{
s.onStorage(cl, s.Store.WriteClient(persistence.Client{
ID: "cl_" + cl.ID,
ClientID: cl.ID,
T: persistence.KClient,
Listener: cl.Listener,
Username: cl.Username,
LWT: persistence.LWT(cl.LWT),
})
}))
}
err = cl.Read(s.processPacket)
if err != nil {
s.closeClient(cl, true)
if s.Events.OnConnect != nil {
s.Events.OnConnect(cl.Info(), events.Packet(pk))
}
s.bytepool.Put(xbr) // Return byte buffers to pools when the client has finished.
s.bytepool.Put(xbw)
if err := cl.Read(s.processPacket); err != nil {
s.sendLWT(cl)
cl.Stop(err)
}
atomic.AddInt64(&s.System.ClientsConnected, -1)
atomic.AddInt64(&s.System.ClientsDisconnected, 1)
err = cl.StopCause() // Determine true cause of stop.
if s.Events.OnDisconnect != nil {
s.Events.OnDisconnect(cl.Info(), err)
}
return err
}
// ackConnection returns a Connack packet to a client.
func (s *Server) ackConnection(cl *clients.Client, ack byte, present bool) error {
return s.writeClient(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Connack,
},
SessionPresent: present,
ReturnCode: ack,
})
}
// inheritClientSession inherits the state of an existing client sharing the same
// connection ID. If cleanSession is true, the state of any previously existing client
// session is abandoned.
func (s *Server) inheritClientSession(pk packets.Packet, cl *clients.Client) bool {
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
existing.Lock()
defer existing.Unlock()
existing.Stop(ErrSessionReestablished) // Issue a stop on the old client.
// Per [MQTT-3.1.2-6]:
// If CleanSession is set to 1, the Client and Server MUST discard any previous Session and start a new one.
// The state associated with a CleanSession MUST NOT be reused in any subsequent session.
if pk.CleanSession || existing.CleanSession {
s.unsubscribeClient(existing)
return false
}
cl.Inflight = existing.Inflight // Take address of existing session.
cl.Subscriptions = existing.Subscriptions
return true
} else {
atomic.AddInt64(&s.System.ClientsTotal, 1)
if atomic.LoadInt64(&s.System.ClientsConnected) > atomic.LoadInt64(&s.System.ClientsMax) {
atomic.AddInt64(&s.System.ClientsMax, 1)
}
return false
}
}
// unsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) unsubscribeClient(cl *clients.Client) {
for k := range cl.Subscriptions {
delete(cl.Subscriptions, k)
if s.Topics.Unsubscribe(k, cl.ID) {
atomic.AddInt64(&s.System.Subscriptions, -1)
}
}
}
// writeClient writes packets to a client connection.
func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
_, err := cl.WritePacket(pk)
if err != nil {
return err
return fmt.Errorf("write: %w", err)
}
return nil
@@ -327,13 +460,14 @@ func (s *Server) processPacket(cl *clients.Client, pk packets.Packet) error {
// establish a new connection on an existing connection. See EstablishConnection
// instead.
func (s *Server) processConnect(cl *clients.Client, pk packets.Packet) error {
s.closeClient(cl, true)
s.sendLWT(cl)
cl.Stop(ErrClientReconnect)
return nil
}
// processDisconnect processes a Disconnect packet.
func (s *Server) processDisconnect(cl *clients.Client, pk packets.Packet) error {
s.closeClient(cl, false)
cl.Stop(ErrClientDisconnect)
return nil
}
@@ -362,14 +496,15 @@ func (s *Server) Publish(topic string, payload []byte, retain bool) error {
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Type: packets.Publish,
Retain: retain,
},
TopicName: topic,
Payload: payload,
}
if retain {
s.retainMessage(pk)
s.retainMessage(&s.inline, pk)
}
// handoff packet to s.inline.pub channel for writing to client buffers
@@ -379,6 +514,16 @@ func (s *Server) Publish(topic string, payload []byte, retain bool) error {
return nil
}
// Info provides pseudo-client information for the inline messages processor.
// It provides a 'client' to which inline retained messages can be assigned.
func (*inlineMessages) Info() events.Client {
return events.Client{
ID: "inline",
Remote: "inline",
Listener: "inline",
}
}
// processPublish processes a Publish packet.
func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
if len(pk.TopicName) >= 4 && pk.TopicName[0:4] == "$SYS" {
@@ -389,8 +534,25 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
return nil
}
// if an OnProcessMessage hook exists, potentially modify the packet.
if s.Events.OnProcessMessage != nil {
pkx, err := s.Events.OnProcessMessage(cl.Info(), events.Packet(pk))
if err == nil {
pk = packets.Packet(pkx) // Only use the new package changes if there's no errors.
} else {
// If the ErrRejectPacket is return, abandon processing the packet.
if err == ErrRejectPacket {
return nil
}
if s.Events.OnError != nil {
s.Events.OnError(cl.Info(), err)
}
}
}
if pk.FixedHeader.Retain {
s.retainMessage(pk)
s.retainMessage(cl, pk)
}
if pk.FixedHeader.Qos > 0 {
@@ -407,12 +569,12 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
// omit errors in case of broken connection / LWT publish. ack send failures
// will be handled by in-flight resending on next reconnect.
s.writeClient(cl, ack)
s.onError(cl.Info(), s.writeClient(cl, ack))
}
// if an OnMessage hook exists, potentially modify the packet.
if s.Events.OnMessage != nil {
if pkx, err := s.Events.OnMessage(events.FromClient(*cl), events.Packet(pk)); err == nil {
if pkx, err := s.Events.OnMessage(cl.Info(), events.Packet(pk)); err == nil {
pk = packets.Packet(pkx)
}
}
@@ -425,21 +587,23 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
// retainMessage adds a message to a topic, and if a persistent store is provided,
// adds the message to the store so it can be reloaded if necessary.
func (s *Server) retainMessage(pk packets.Packet) {
func (s *Server) retainMessage(cl events.Clientlike, pk packets.Packet) {
out := pk.PublishCopy()
q := s.Topics.RetainMessage(out)
atomic.AddInt64(&s.System.Retained, q)
r := s.Topics.RetainMessage(out)
atomic.AddInt64(&s.System.Retained, r)
if s.Store != nil {
if q == 1 {
s.Store.WriteRetained(persistence.Message{
ID: "ret_" + out.TopicName,
id := "ret_" + out.TopicName
if r == 1 {
s.onStorage(cl, s.Store.WriteRetained(persistence.Message{
ID: id,
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader(out.FixedHeader),
TopicName: out.TopicName,
Payload: out.Payload,
})
}))
} else {
s.Store.DeleteRetained("ret_" + out.TopicName)
s.onStorage(cl, s.Store.DeleteRetained(id))
}
}
}
@@ -449,6 +613,14 @@ func (s *Server) retainMessage(pk packets.Packet) {
func (s *Server) publishToSubscribers(pk packets.Packet) {
for id, qos := range s.Topics.Subscribers(pk.TopicName) {
if client, ok := s.Clients.Get(id); ok {
// If the AllowClients value is set, only deliver the packet if the subscribed
// client exists in the AllowClients value. For use with the OnMessage event hook
// in cases where you want to publish messages to clients selectively.
if pk.AllowClients != nil && !utils.InSliceString(pk.AllowClients, id) {
continue
}
out := pk.PublishCopy()
if qos > out.FixedHeader.Qos { // Inherit higher desired qos values.
out.FixedHeader.Qos = qos
@@ -473,18 +645,18 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
}
if s.Store != nil {
s.Store.WriteInflight(persistence.Message{
ID: "if_" + client.ID + "_" + strconv.Itoa(int(out.PacketID)),
T: persistence.KRetained,
s.onStorage(client, s.Store.WriteInflight(persistence.Message{
ID: persistentID(client, out),
T: persistence.KInflight,
FixedHeader: persistence.FixedHeader(out.FixedHeader),
TopicName: out.TopicName,
Payload: out.Payload,
Sent: sent,
})
}))
}
}
s.writeClient(client, out)
s.onError(client.Info(), s.writeClient(client, out))
}
}
}
@@ -496,7 +668,7 @@ func (s *Server) processPuback(cl *clients.Client, pk packets.Packet) error {
atomic.AddInt64(&s.System.Inflight, -1)
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
}
return nil
}
@@ -538,7 +710,7 @@ func (s *Server) processPubrel(cl *clients.Client, pk packets.Packet) error {
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
}
return nil
@@ -551,7 +723,7 @@ func (s *Server) processPubcomp(cl *clients.Client, pk packets.Packet) error {
atomic.AddInt64(&s.System.Inflight, -1)
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
}
return nil
}
@@ -571,13 +743,13 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
retCodes[i] = pk.Qoss[i]
if s.Store != nil {
s.Store.WriteSubscription(persistence.Subscription{
s.onStorage(cl, s.Store.WriteSubscription(persistence.Subscription{
ID: "sub_" + cl.ID + ":" + pk.Topics[i],
T: persistence.KSubscription,
Filter: pk.Topics[i],
Client: cl.ID,
QoS: pk.Qoss[i],
})
}))
}
}
}
@@ -593,10 +765,15 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
return err
}
// Publish out any retained messages matching the subscription filter.
// Publish out any retained messages matching the subscription filter and the user has
// been allowed to subscribe to.
for i := 0; i < len(pk.Topics); i++ {
if retCodes[i] == packets.ErrSubAckNetworkError {
continue
}
for _, pkv := range s.Topics.Messages(pk.Topics[i]) {
s.writeClient(cl, pkv) // omit errors, prefer continuing.
s.onError(cl.Info(), s.writeClient(cl, pkv))
}
}
@@ -626,6 +803,17 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
return nil
}
// atomicItoa reads an *int64 and formats a decimal string.
func atomicItoa(ptr *int64) string {
return strconv.FormatInt(atomic.LoadInt64(ptr), 10)
}
// persistentID return a string combining the client and packet
// identifiers for use with the persistence layer.
func persistentID(client *clients.Client, pk packets.Packet) string {
return "if_" + client.ID + "_" + pk.FormatID()
}
// publishSysTopics publishes the current values to the server $SYS topics.
// Due to the int to string conversions this method is not as cheap as
// some of the others so the publishing interval should be set appropriately.
@@ -637,26 +825,27 @@ func (s *Server) publishSysTopics() {
},
}
s.System.Uptime = time.Now().Unix() - s.System.Started
uptime := time.Now().Unix() - atomic.LoadInt64(&s.System.Started)
atomic.StoreInt64(&s.System.Uptime, uptime)
topics := map[string]string{
"$SYS/broker/version": s.System.Version,
"$SYS/broker/uptime": strconv.Itoa(int(s.System.Uptime)),
"$SYS/broker/timestamp": strconv.Itoa(int(s.System.Started)),
"$SYS/broker/load/bytes/received": strconv.Itoa(int(s.System.BytesRecv)),
"$SYS/broker/load/bytes/sent": strconv.Itoa(int(s.System.BytesSent)),
"$SYS/broker/clients/connected": strconv.Itoa(int(s.System.ClientsConnected)),
"$SYS/broker/clients/disconnected": strconv.Itoa(int(s.System.ClientsDisconnected)),
"$SYS/broker/clients/maximum": strconv.Itoa(int(s.System.ClientsMax)),
"$SYS/broker/clients/total": strconv.Itoa(int(s.System.ClientsTotal)),
"$SYS/broker/connections/total": strconv.Itoa(int(s.System.ConnectionsTotal)),
"$SYS/broker/messages/received": strconv.Itoa(int(s.System.MessagesRecv)),
"$SYS/broker/messages/sent": strconv.Itoa(int(s.System.MessagesSent)),
"$SYS/broker/messages/publish/dropped": strconv.Itoa(int(s.System.PublishDropped)),
"$SYS/broker/messages/publish/received": strconv.Itoa(int(s.System.PublishRecv)),
"$SYS/broker/messages/publish/sent": strconv.Itoa(int(s.System.PublishSent)),
"$SYS/broker/messages/retained/count": strconv.Itoa(int(s.System.Retained)),
"$SYS/broker/messages/inflight": strconv.Itoa(int(s.System.Inflight)),
"$SYS/broker/subscriptions/count": strconv.Itoa(int(s.System.Subscriptions)),
"$SYS/broker/uptime": atomicItoa(&s.System.Uptime),
"$SYS/broker/timestamp": atomicItoa(&s.System.Started),
"$SYS/broker/load/bytes/received": atomicItoa(&s.System.BytesRecv),
"$SYS/broker/load/bytes/sent": atomicItoa(&s.System.BytesSent),
"$SYS/broker/clients/connected": atomicItoa(&s.System.ClientsConnected),
"$SYS/broker/clients/disconnected": atomicItoa(&s.System.ClientsDisconnected),
"$SYS/broker/clients/maximum": atomicItoa(&s.System.ClientsMax),
"$SYS/broker/clients/total": atomicItoa(&s.System.ClientsTotal),
"$SYS/broker/connections/total": atomicItoa(&s.System.ConnectionsTotal),
"$SYS/broker/messages/received": atomicItoa(&s.System.MessagesRecv),
"$SYS/broker/messages/sent": atomicItoa(&s.System.MessagesSent),
"$SYS/broker/messages/publish/dropped": atomicItoa(&s.System.PublishDropped),
"$SYS/broker/messages/publish/received": atomicItoa(&s.System.PublishRecv),
"$SYS/broker/messages/publish/sent": atomicItoa(&s.System.PublishSent),
"$SYS/broker/messages/retained/count": atomicItoa(&s.System.Retained),
"$SYS/broker/messages/inflight": atomicItoa(&s.System.Inflight),
"$SYS/broker/subscriptions/count": atomicItoa(&s.System.Subscriptions),
}
for topic, payload := range topics {
@@ -668,10 +857,10 @@ func (s *Server) publishSysTopics() {
}
if s.Store != nil {
s.Store.WriteServerInfo(persistence.ServerInfo{
s.onStorage(&s.inline, s.Store.WriteServerInfo(persistence.ServerInfo{
Info: *s.System,
ID: persistence.KServerInfo,
})
}))
}
}
@@ -691,7 +880,7 @@ func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)))
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, tk.Packet)))
}
continue
@@ -715,15 +904,15 @@ func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
}
if s.Store != nil {
s.Store.WriteInflight(persistence.Message{
ID: "if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)),
T: persistence.KRetained,
s.onStorage(cl, s.Store.WriteInflight(persistence.Message{
ID: persistentID(cl, tk.Packet),
T: persistence.KInflight,
FixedHeader: persistence.FixedHeader(tk.Packet.FixedHeader),
TopicName: tk.Packet.TopicName,
Payload: tk.Packet.Payload,
Sent: tk.Sent,
Resends: tk.Resends,
})
}))
}
}
@@ -746,15 +935,14 @@ func (s *Server) Close() error {
func (s *Server) closeListenerClients(listener string) {
clients := s.Clients.GetByListener(listener)
for _, cl := range clients {
s.closeClient(cl, false) // omit errors
cl.Stop(ErrServerShutdown)
}
}
// closeClient closes a client connection and publishes any LWT messages.
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
if sendLWT && cl.LWT.Topic != "" {
s.processPublish(cl, packets.Packet{
// sendLWT issues an LWT message to a topic when a client disconnects.
func (s *Server) sendLWT(cl *clients.Client) error {
if cl.LWT.Topic != "" {
err := s.processPublish(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Retain: cl.LWT.Retain,
@@ -763,10 +951,11 @@ func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
TopicName: cl.LWT.Topic,
Payload: cl.LWT.Message,
})
if err != nil {
return s.onError(cl.Info(), fmt.Errorf("send lwt: %s %w; %+v", cl.ID, err, cl.LWT))
}
}
cl.Stop()
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
package system
import (
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
func TestInfoAlignment(t *testing.T) {
typ := reflect.TypeOf(Info{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
switch f.Type.Kind() {
case reflect.Int64, reflect.Uint64:
require.Equalf(t, uintptr(0), f.Offset%8,
"%s requires 64-bit alignment for atomic: offset %d",
f.Name, f.Offset)
}
}
}

View File

@@ -6,6 +6,13 @@
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
---
⚠️ **[The Gorilla WebSocket Package is looking for a new maintainer](https://github.com/gorilla/websocket/issues/370)**
---
### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
@@ -30,35 +37,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="https://github.com/crossbario/autobahn-testsuite">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.

View File

@@ -48,15 +48,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
}
// A Dialer contains options for connecting to WebSocket server.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used.
// NetDialContext is nil, NetDial is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialContext is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
@@ -65,6 +73,8 @@ type Dialer struct {
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
@@ -176,7 +186,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
req := &http.Request{
Method: "GET",
Method: http.MethodGet,
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
@@ -237,13 +247,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
switch u.Scheme {
case "http":
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
} else if d.NetDial != nil {
netDial = d.NetDial
} else {
case "https":
if d.NetDialTLSContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
default:
return nil, nil, errMalformedURL
}
if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
@@ -304,7 +333,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()
if u.Scheme == "https" {
if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
@@ -312,11 +343,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
var err error
if trace != nil {
err = doHandshakeWithTrace(trace, tlsConn, cfg)
} else {
err = doHandshake(tlsConn, cfg)
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(ctx, tlsConn, cfg)
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
if err != nil {
@@ -348,8 +380,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
@@ -382,14 +414,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil
}
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
return cfg.Clone()
}

View File

@@ -1,16 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

View File

@@ -1,38 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
import "crypto/tls"
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@@ -13,6 +13,7 @@ import (
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
@@ -401,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return nil
}
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@@ -794,47 +801,69 @@ func (c *Conn) advanceFrame() (int, error) {
}
// 2. Read and parse first two bytes of frame header.
// To aid debugging, collect and report all errors in the first two bytes
// of the header.
var errors []string
p, err := c.read(2)
if err != nil {
return noFrame, err
}
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
final := p[0]&finalBit != 0
rsv1 := p[0]&rsv1Bit != 0
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true
p[0] &^= rsv1Bit
if rsv1 {
if c.newDecompressionReader != nil {
c.readDecompress = true
} else {
errors = append(errors, "RSV1 set")
}
}
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
if rsv2 {
errors = append(errors, "RSV2 set")
}
if rsv3 {
errors = append(errors, "RSV3 set")
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125")
errors = append(errors, "len > 125 for control")
}
if !final {
return noFrame, c.handleProtocolError("control frame not final")
errors = append(errors, "FIN not set on control")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame")
errors = append(errors, "data before FIN")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame")
errors = append(errors, "continuation after FIN")
}
c.readFinal = final
default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
}
if mask != c.isServer {
errors = append(errors, "bad MASK")
}
if len(errors) > 0 {
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
}
// 3. Read and parse frame length as per
@@ -872,10 +901,6 @@ func (c *Conn) advanceFrame() (int, error) {
// 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
@@ -935,7 +960,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
@@ -952,7 +977,11 @@ func (c *Conn) advanceFrame() (int, error) {
}
func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
data := FormatCloseMessage(CloseProtocolError, message)
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}

View File

@@ -1,15 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "net"
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}

View File

@@ -1,18 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
func (c *Conn) writeBufs(bufs ...[]byte) error {
for _, buf := range bufs {
if len(buf) > 0 {
if _, err := c.conn.Write(buf); err != nil {
return err
}
}
}
return nil
}

View File

@@ -2,6 +2,7 @@
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build !appengine
// +build !appengine
package websocket

View File

@@ -2,6 +2,7 @@
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build appengine
// +build appengine
package websocket

View File

@@ -48,7 +48,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
}
connectReq := &http.Request{
Method: "CONNECT",
Method: http.MethodConnect,
URL: &url.URL{Opaque: addr},
Host: addr,
Header: connectHeader,

View File

@@ -23,6 +23,8 @@ func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
//
// It is safe to call Upgrader's methods concurrently.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
@@ -115,8 +117,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-WebSocket-Protocol).
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
@@ -131,7 +133,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
}
if r.Method != "GET" {
if r.Method != http.MethodGet {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
}

21
vendor/github.com/gorilla/websocket/tls_handshake.go generated vendored Normal file
View File

@@ -0,0 +1,21 @@
//go:build go1.17
// +build go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,21 @@
//go:build !go1.17
// +build !go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@@ -1,19 +0,0 @@
// +build go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
if trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(tlsConn, cfg)
if trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
return err
}

View File

@@ -1,12 +0,0 @@
// +build !go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
return doHandshake(tlsConn, cfg)
}

View File

@@ -24,6 +24,13 @@ const (
// Denotes that the value as been copied
hasCopied
// Some default converter types for a nicer syntax
String string = ""
Bool bool = false
Int int = 0
Float32 float32 = 0
Float64 float64 = 0
)
// Option sets copy options
@@ -32,6 +39,18 @@ type Option struct {
// struct having all it's fields set to their zero values respectively (see IsZero() in reflect/value.go)
IgnoreEmpty bool
DeepCopy bool
Converters []TypeConverter
}
type TypeConverter struct {
SrcType interface{}
DstType interface{}
Fn func(src interface{}) (interface{}, error)
}
type converterPair struct {
SrcType reflect.Type
DstType reflect.Type
}
// Tag Flags
@@ -59,12 +78,27 @@ func CopyWithOption(toValue interface{}, fromValue interface{}, opt Option) (err
func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) {
var (
isSlice bool
amount = 1
from = indirect(reflect.ValueOf(fromValue))
to = indirect(reflect.ValueOf(toValue))
isSlice bool
amount = 1
from = indirect(reflect.ValueOf(fromValue))
to = indirect(reflect.ValueOf(toValue))
converters map[converterPair]TypeConverter
)
// save convertes into map for faster lookup
for i := range opt.Converters {
if converters == nil {
converters = make(map[converterPair]TypeConverter)
}
pair := converterPair{
SrcType: reflect.TypeOf(opt.Converters[i].SrcType),
DstType: reflect.TypeOf(opt.Converters[i].DstType),
}
converters[pair] = opt.Converters[i]
}
if !to.CanAddr() {
return ErrInvalidCopyDestination
}
@@ -113,13 +147,16 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
for _, k := range from.MapKeys() {
toKey := indirect(reflect.New(toType.Key()))
if !set(toKey, k, opt.DeepCopy) {
if !set(toKey, k, opt.DeepCopy, converters) {
return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key())
}
elemType, _ := indirectType(toType.Elem())
elemType := toType.Elem()
if elemType.Kind() != reflect.Slice {
elemType, _ = indirectType(elemType)
}
toValue := indirect(reflect.New(elemType))
if !set(toValue, from.MapIndex(k), opt.DeepCopy) {
if !set(toValue, from.MapIndex(k), opt.DeepCopy, converters) {
if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil {
return err
}
@@ -148,7 +185,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem()))
}
if !set(to.Index(i), from.Index(i), opt.DeepCopy) {
if !set(to.Index(i), from.Index(i), opt.DeepCopy, converters) {
// ignore error while copy slice element
err = copier(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt)
if err != nil {
@@ -203,6 +240,8 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
// check source
if source.IsValid() {
copyUnexportedStructFields(dest, source)
// Copy from source field to dest field or method
fromTypeFields := deepFields(fromType)
for _, field := range fromTypeFields {
@@ -249,7 +288,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
toField := dest.FieldByName(destFieldName)
if toField.IsValid() {
if toField.CanSet() {
if !set(toField, fromField, opt.DeepCopy) {
if !set(toField, fromField, opt.DeepCopy, converters) {
if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil {
return err
}
@@ -291,7 +330,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
if toField := dest.FieldByName(destFieldName); toField.IsValid() && toField.CanSet() {
values := fromMethod.Call([]reflect.Value{})
if len(values) >= 1 {
set(toField, values[0], opt.DeepCopy)
set(toField, values[0], opt.DeepCopy, converters)
}
}
}
@@ -303,7 +342,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
if to.Len() < i+1 {
to.Set(reflect.Append(to, dest.Addr()))
} else {
if !set(to.Index(i), dest.Addr(), opt.DeepCopy) {
if !set(to.Index(i), dest.Addr(), opt.DeepCopy, converters) {
// ignore error while copy slice element
err = copier(to.Index(i).Addr().Interface(), dest.Addr().Interface(), opt)
if err != nil {
@@ -315,7 +354,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
if to.Len() < i+1 {
to.Set(reflect.Append(to, dest))
} else {
if !set(to.Index(i), dest, opt.DeepCopy) {
if !set(to.Index(i), dest, opt.DeepCopy, converters) {
// ignore error while copy slice element
err = copier(to.Index(i).Addr().Interface(), dest.Interface(), opt)
if err != nil {
@@ -334,6 +373,24 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
return
}
func copyUnexportedStructFields(to, from reflect.Value) {
if from.Kind() != reflect.Struct || to.Kind() != reflect.Struct || !from.Type().AssignableTo(to.Type()) {
return
}
// create a shallow copy of 'to' to get all fields
tmp := indirect(reflect.New(to.Type()))
tmp.Set(from)
// revert exported fields
for i := 0; i < to.NumField(); i++ {
if tmp.Field(i).CanSet() {
tmp.Field(i).Set(to.Field(i))
}
}
to.Set(tmp)
}
func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool {
if !ignoreEmpty {
return false
@@ -352,10 +409,10 @@ func deepFields(reflectType reflect.Type) []reflect.StructField {
// field name. It is empty for upper case (exported) field names.
// See https://golang.org/ref/spec#Uniqueness_of_identifiers
if v.PkgPath == "" {
fields = append(fields, v)
if v.Anonymous {
// also consider fields of anonymous fields as fields of the root
fields = append(fields, deepFields(v.Type)...)
} else {
fields = append(fields, v)
}
}
}
@@ -381,8 +438,14 @@ func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) {
return reflectType, isPtr
}
func set(to, from reflect.Value, deepCopy bool) bool {
func set(to, from reflect.Value, deepCopy bool, converters map[converterPair]TypeConverter) bool {
if from.IsValid() {
if ok, err := lookupAndCopyWithConverter(to, from, converters); err != nil {
return false
} else if ok {
return true
}
if to.Kind() == reflect.Ptr {
// set `to` to nil if from is nil
if from.Kind() == reflect.Ptr && from.IsNil() {
@@ -416,6 +479,9 @@ func set(to, from reflect.Value, deepCopy bool) bool {
toKind = reflect.TypeOf(to.Interface()).Kind()
}
}
if from.Kind() == reflect.Ptr && from.IsNil() {
return true
}
if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice {
return false
}
@@ -457,7 +523,7 @@ func set(to, from reflect.Value, deepCopy bool) bool {
to.Set(rv)
}
} else if from.Kind() == reflect.Ptr {
return set(to, from.Elem(), deepCopy)
return set(to, from.Elem(), deepCopy, converters)
} else {
return false
}
@@ -466,6 +532,33 @@ func set(to, from reflect.Value, deepCopy bool) bool {
return true
}
// lookupAndCopyWithConverter looks up the type pair, on success the TypeConverter Fn func is called to copy src to dst field.
func lookupAndCopyWithConverter(to, from reflect.Value, converters map[converterPair]TypeConverter) (copied bool, err error) {
pair := converterPair{
SrcType: from.Type(),
DstType: to.Type(),
}
if cnv, ok := converters[pair]; ok {
result, err := cnv.Fn(from.Interface())
if err != nil {
return false, err
}
if result != nil {
to.Set(reflect.ValueOf(result))
} else {
// in case we've got a nil value to copy
to.Set(reflect.Zero(to.Type()))
}
return true, nil
}
return false, nil
}
// parseTags Parses struct tags and returns uint8 bit flags.
func parseTags(tag string) (flg uint8, name string, err error) {
for _, t := range strings.Split(tag, ",") {

1
vendor/github.com/rs/xid/README.md generated vendored
View File

@@ -69,6 +69,7 @@ References:
- Rust port by [Jérôme Renard](https://github.com/jeromer/): https://github.com/jeromer/libxid
- Ruby port by [Valar](https://github.com/valarpirai/): https://github.com/valarpirai/ruby_xid
- Java port by [0xShamil](https://github.com/0xShamil/): https://github.com/0xShamil/java-xid
- Dart port by [Peter Bwire](https://github.com/pitabwire): https://pub.dev/packages/xid
## Install

11
vendor/github.com/rs/xid/error.go generated vendored Normal file
View File

@@ -0,0 +1,11 @@
package xid
const (
// ErrInvalidID is returned when trying to unmarshal an invalid ID.
ErrInvalidID strErr = "xid: invalid ID"
)
// strErr allows declaring errors as constants.
type strErr string
func (err strErr) Error() string { return string(err) }

26
vendor/github.com/rs/xid/id.go generated vendored
View File

@@ -47,7 +47,6 @@ import (
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
@@ -73,9 +72,6 @@ const (
)
var (
// ErrInvalidID is returned when trying to unmarshal an invalid ID
ErrInvalidID = errors.New("xid: invalid ID")
// objectIDCounter is atomically incremented when generating a new ObjectId
// using NewObjectId() function. It's used as a counter part of an id.
// This id is initialized with a random value.
@@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error {
return ErrInvalidID
}
}
decode(id, text)
if !decode(id, text) {
return ErrInvalidID
}
return nil
}
@@ -253,11 +251,15 @@ func (id *ID) UnmarshalJSON(b []byte) error {
*id = nilID
return nil
}
// Check the slice length to prevent panic on passing it to UnmarshalText()
if len(b) < 2 {
return ErrInvalidID
}
return id.UnmarshalText(b[1 : len(b)-1])
}
// decode by unrolling the stdlib base32 algorithm + removing all safe checks
func decode(id *ID, src []byte) {
// decode by unrolling the stdlib base32 algorithm + customized safe check.
func decode(id *ID, src []byte) bool {
_ = src[19]
_ = id[11]
@@ -273,6 +275,16 @@ func decode(id *ID, src []byte) {
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2
// Validate that there are no discarer bits (padding) in src that would
// cause the string-encoded id not to equal src.
var check [4]byte
check[3] = encoding[(id[11]<<4)&0x1F]
check[2] = encoding[(id[11]>>1)&0x1F]
check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
check[0] = encoding[id[10]>>3]
return bytes.Equal([]byte(src[16:20]), check[:])
}
// Time returns the timestamp part of the id.

View File

@@ -3,6 +3,7 @@ package assert
import (
"fmt"
"reflect"
"time"
)
type CompareType int
@@ -30,6 +31,8 @@ var (
float64Type = reflect.TypeOf(float64(1))
stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
@@ -299,6 +302,27 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return compareLess, true
}
}
// Check for known struct types we can check for compare results.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) {
break
}
// time.Time can compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}
timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
}
return compareEqual, false
@@ -310,7 +334,10 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// GreaterOrEqual asserts that the first element is greater than or equal to the second
@@ -320,7 +347,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// Less asserts that the first element is less than the second
@@ -329,7 +359,10 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// LessOrEqual asserts that the first element is less than or equal to the second
@@ -339,7 +372,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}
// Positive asserts that the specified element is positive
@@ -347,8 +383,11 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
// assert.Positive(t, 1)
// assert.Positive(t, 1.23)
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs)
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
}
// Negative asserts that the specified element is negative
@@ -356,8 +395,11 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
// assert.Negative(t, -1)
// assert.Negative(t, -1.23)
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs)
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
}
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {

View File

@@ -0,0 +1,16 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatability
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View File

@@ -0,0 +1,16 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View File

@@ -123,6 +123,18 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
}
// ErrorIsf asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {

View File

@@ -222,6 +222,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
return ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool {

View File

@@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
// assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// IsNonIncreasing asserts that the collection is not increasing
@@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// IsDecreasing asserts that the collection is decreasing
@@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{})
// assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// IsNonDecreasing asserts that the collection is not decreasing
@@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
// assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}

View File

@@ -718,10 +718,14 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
// return (false, false) if impossible.
// return (true, false) if element was not found.
// return (true, true) if element was found.
func includeElement(list interface{}, element interface{}) (ok, found bool) {
func containsElement(list interface{}, element interface{}) (ok, found bool) {
listValue := reflect.ValueOf(list)
listKind := reflect.TypeOf(list).Kind()
listType := reflect.TypeOf(list)
if listType == nil {
return false, false
}
listKind := listType.Kind()
defer func() {
if e := recover(); e != nil {
ok = false
@@ -764,7 +768,7 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
}
@@ -787,7 +791,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
}
@@ -831,7 +835,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -852,7 +856,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
h.Helper()
}
if subset == nil {
return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...)
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
@@ -875,7 +879,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -1000,27 +1004,21 @@ func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
type PanicTestFunc func()
// didPanic returns true if the function passed to it panics. Otherwise, it returns false.
func didPanic(f PanicTestFunc) (bool, interface{}, string) {
didPanic := false
var message interface{}
var stack string
func() {
defer func() {
if message = recover(); message != nil {
didPanic = true
stack = string(debug.Stack())
}
}()
// call the target function
f()
func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string) {
didPanic = true
defer func() {
message = recover()
if didPanic {
stack = string(debug.Stack())
}
}()
return didPanic, message, stack
// call the target function
f()
didPanic = false
return
}
// Panics asserts that the code inside the specified PanicTestFunc panics.
@@ -1161,11 +1159,15 @@ func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs
bf, bok := toFloat(actual)
if !aok || !bok {
return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...)
return Fail(t, "Parameters must be numerical", msgAndArgs...)
}
if math.IsNaN(af) && math.IsNaN(bf) {
return true
}
if math.IsNaN(af) {
return Fail(t, fmt.Sprintf("Expected must not be NaN"), msgAndArgs...)
return Fail(t, "Expected must not be NaN", msgAndArgs...)
}
if math.IsNaN(bf) {
@@ -1188,7 +1190,7 @@ func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAn
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1250,8 +1252,12 @@ func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, m
func calcRelativeError(expected, actual interface{}) (float64, error) {
af, aok := toFloat(expected)
if !aok {
return 0, fmt.Errorf("expected value %q cannot be converted to float", expected)
bf, bok := toFloat(actual)
if !aok || !bok {
return 0, fmt.Errorf("Parameters must be numerical")
}
if math.IsNaN(af) && math.IsNaN(bf) {
return 0, nil
}
if math.IsNaN(af) {
return 0, errors.New("expected value must not be NaN")
@@ -1259,10 +1265,6 @@ func calcRelativeError(expected, actual interface{}) (float64, error) {
if af == 0 {
return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error")
}
bf, bok := toFloat(actual)
if !bok {
return 0, fmt.Errorf("actual value %q cannot be converted to float", actual)
}
if math.IsNaN(bf) {
return 0, errors.New("actual value must not be NaN")
}
@@ -1298,7 +1300,7 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1375,6 +1377,27 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
return true
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !Error(t, theError, msgAndArgs...) {
return false
}
actual := theError.Error()
if !strings.Contains(actual, contains) {
return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...)
}
return true
}
// matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool {
@@ -1588,12 +1611,17 @@ func diff(expected interface{}, actual interface{}) string {
}
var e, a string
if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
} else {
switch et {
case reflect.TypeOf(""):
e = reflect.ValueOf(expected).String()
a = reflect.ValueOf(actual).String()
case reflect.TypeOf(time.Time{}):
e = spewConfigStringerEnabled.Sdump(expected)
a = spewConfigStringerEnabled.Sdump(actual)
default:
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
}
diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
@@ -1625,6 +1653,14 @@ var spewConfig = spew.ConfigState{
MaxDepth: 10,
}
var spewConfigStringerEnabled = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
MaxDepth: 10,
}
type tHelper interface {
Helper()
}

View File

@@ -280,6 +280,36 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
t.FailNow()
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContains(t, theError, contains, msgAndArgs...) {
return
}
t.FailNow()
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if assert.ErrorContainsf(t, theError, contains, msg, args...) {
return
}
t.FailNow()
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) {

View File

@@ -223,6 +223,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) {

2
vendor/go.etcd.io/bbolt/.gitignore generated vendored
View File

@@ -3,5 +3,3 @@
*.swp
/bin/
cover.out
/.idea
*.iml

View File

@@ -4,10 +4,9 @@ go_import_path: go.etcd.io/bbolt
sudo: false
go:
- 1.15
- 1.12
before_install:
- go get -v golang.org/x/sys/unix
- go get -v honnef.co/go/tools/...
- go get -v github.com/kisielk/errcheck

2
vendor/go.etcd.io/bbolt/Makefile generated vendored
View File

@@ -2,6 +2,8 @@ BRANCH=`git rev-parse --abbrev-ref HEAD`
COMMIT=`git rev-parse --short HEAD`
GOLDFLAGS="-X main.branch $(BRANCH) -X main.commit $(COMMIT)"
default: build
race:
@TEST_FREELIST_TYPE=hashmap go test -v -race -test.run="TestSimulate_(100op|1000op)"
@echo "array freelist test"

5
vendor/go.etcd.io/bbolt/README.md generated vendored
View File

@@ -908,14 +908,12 @@ Below is a list of public, open source projects that use Bolt:
* [BoltStore](https://github.com/yosssi/boltstore) - Session store using Bolt.
* [Boltdb Boilerplate](https://github.com/bobintornado/boltdb-boilerplate) - Boilerplate wrapper around bolt aiming to make simple calls one-liners.
* [BoltDbWeb](https://github.com/evnix/boltdbweb) - A web based GUI for BoltDB files.
* [BoltDB Viewer](https://github.com/zc310/rich_boltdb) - A BoltDB Viewer Can run on WindowsLinuxAndroid system.
* [bleve](http://www.blevesearch.com/) - A pure Go search engine similar to ElasticSearch that uses Bolt as the default storage backend.
* [btcwallet](https://github.com/btcsuite/btcwallet) - A bitcoin wallet.
* [buckets](https://github.com/joyrexus/buckets) - a bolt wrapper streamlining
simple tx and key scans.
* [cayley](https://github.com/google/cayley) - Cayley is an open-source graph database using Bolt as optional backend.
* [ChainStore](https://github.com/pressly/chainstore) - Simple key-value interface to a variety of storage engines organized as a chain of operations.
* [🌰 Chestnut](https://github.com/jrapoport/chestnut) - Chestnut is encrypted storage for Go.
* [Consul](https://github.com/hashicorp/consul) - Consul is service discovery and configuration made easy. Distributed, highly available, and datacenter-aware.
* [DVID](https://github.com/janelia-flyem/dvid) - Added Bolt as optional storage engine and testing it against Basho-tuned leveldb.
* [dcrwallet](https://github.com/decred/dcrwallet) - A wallet for the Decred cryptocurrency.
@@ -940,8 +938,9 @@ Below is a list of public, open source projects that use Bolt:
* [MetricBase](https://github.com/msiebuhr/MetricBase) - Single-binary version of Graphite.
* [MuLiFS](https://github.com/dankomiocevic/mulifs) - Music Library Filesystem creates a filesystem to organise your music files.
* [NATS](https://github.com/nats-io/nats-streaming-server) - NATS Streaming uses bbolt for message and metadata storage.
* [Operation Go: A Routine Mission](http://gocode.io) - An online programming game for Golang using Bolt for user accounts and a leaderboard.
* [photosite/session](https://godoc.org/bitbucket.org/kardianos/photosite/session) - Sessions for a photo viewing site.
* [Prometheus Annotation Server](https://github.com/oliver006/prom_annotation_server) - Annotation server for PromDash & Prometheus service monitoring system.
* [Rain](https://github.com/cenkalti/rain) - BitTorrent client and library.
* [reef-pi](https://github.com/reef-pi/reef-pi) - reef-pi is an award winning, modular, DIY reef tank controller using easy to learn electronics based on a Raspberry Pi.
* [Request Baskets](https://github.com/darklynx/request-baskets) - A web service to collect arbitrary HTTP requests and inspect them via REST API or simple web UI, similar to [RequestBin](http://requestb.in/) service
* [Seaweed File System](https://github.com/chrislusf/seaweedfs) - Highly scalable distributed key~file system with O(1) disk read.

17
vendor/go.etcd.io/bbolt/bolt_unix.go generated vendored
View File

@@ -7,8 +7,6 @@ import (
"syscall"
"time"
"unsafe"
"golang.org/x/sys/unix"
)
// flock acquires an advisory lock on a file descriptor.
@@ -51,13 +49,13 @@ func funlock(db *DB) error {
// mmap memory maps a DB's data file.
func mmap(db *DB, sz int) error {
// Map the data file to memory.
b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
if err != nil {
return err
}
// Advise the kernel that the mmap is accessed randomly.
err = unix.Madvise(b, syscall.MADV_RANDOM)
err = madvise(b, syscall.MADV_RANDOM)
if err != nil && err != syscall.ENOSYS {
// Ignore not implemented error in kernel because it still works.
return fmt.Errorf("madvise: %s", err)
@@ -78,9 +76,18 @@ func munmap(db *DB) error {
}
// Unmap using the original byte slice.
err := unix.Munmap(db.dataref)
err := syscall.Munmap(db.dataref)
db.dataref = nil
db.data = nil
db.datasz = 0
return err
}
// NOTE: This function is copied from stdlib because it is not available on darwin.
func madvise(b []byte, advice int) (err error) {
_, _, e1 := syscall.Syscall(syscall.SYS_MADVISE, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), uintptr(advice))
if e1 != 0 {
err = e1
}
return
}

114
vendor/go.etcd.io/bbolt/compact.go generated vendored
View File

@@ -1,114 +0,0 @@
package bbolt
// Compact will create a copy of the source DB and in the destination DB. This may
// reclaim space that the source database no longer has use for. txMaxSize can be
// used to limit the transactions size of this process and may trigger intermittent
// commits. A value of zero will ignore transaction sizes.
// TODO: merge with: https://github.com/etcd-io/etcd/blob/b7f0f52a16dbf83f18ca1d803f7892d750366a94/mvcc/backend/backend.go#L349
func Compact(dst, src *DB, txMaxSize int64) error {
// commit regularly, or we'll run out of memory for large datasets if using one transaction.
var size int64
tx, err := dst.Begin(true)
if err != nil {
return err
}
defer tx.Rollback()
if err := walk(src, func(keys [][]byte, k, v []byte, seq uint64) error {
// On each key/value, check if we have exceeded tx size.
sz := int64(len(k) + len(v))
if size+sz > txMaxSize && txMaxSize != 0 {
// Commit previous transaction.
if err := tx.Commit(); err != nil {
return err
}
// Start new transaction.
tx, err = dst.Begin(true)
if err != nil {
return err
}
size = 0
}
size += sz
// Create bucket on the root transaction if this is the first level.
nk := len(keys)
if nk == 0 {
bkt, err := tx.CreateBucket(k)
if err != nil {
return err
}
if err := bkt.SetSequence(seq); err != nil {
return err
}
return nil
}
// Create buckets on subsequent levels, if necessary.
b := tx.Bucket(keys[0])
if nk > 1 {
for _, k := range keys[1:] {
b = b.Bucket(k)
}
}
// Fill the entire page for best compaction.
b.FillPercent = 1.0
// If there is no value then this is a bucket call.
if v == nil {
bkt, err := b.CreateBucket(k)
if err != nil {
return err
}
if err := bkt.SetSequence(seq); err != nil {
return err
}
return nil
}
// Otherwise treat it as a key/value pair.
return b.Put(k, v)
}); err != nil {
return err
}
return tx.Commit()
}
// walkFunc is the type of the function called for keys (buckets and "normal"
// values) discovered by Walk. keys is the list of keys to descend to the bucket
// owning the discovered key/value pair k/v.
type walkFunc func(keys [][]byte, k, v []byte, seq uint64) error
// walk walks recursively the bolt database db, calling walkFn for each key it finds.
func walk(db *DB, walkFn walkFunc) error {
return db.View(func(tx *Tx) error {
return tx.ForEach(func(name []byte, b *Bucket) error {
return walkBucket(b, nil, name, nil, b.Sequence(), walkFn)
})
})
}
func walkBucket(b *Bucket, keypath [][]byte, k, v []byte, seq uint64, fn walkFunc) error {
// Execute callback.
if err := fn(keypath, k, v, seq); err != nil {
return err
}
// If this is not a bucket then stop.
if v != nil {
return nil
}
// Iterate over each child key/value.
keypath = append(keypath, k)
return b.ForEach(func(k, v []byte) error {
if v == nil {
bkt := b.Bucket(k)
return walkBucket(bkt, keypath, k, nil, bkt.Sequence(), fn)
}
return walkBucket(b, keypath, k, v, b.Sequence(), fn)
})
}

66
vendor/go.etcd.io/bbolt/db.go generated vendored
View File

@@ -120,12 +120,6 @@ type DB struct {
// of truncate() and fsync() when growing the data file.
AllocSize int
// Mlock locks database file in memory when set to true.
// It prevents major page faults, however used memory can't be reclaimed.
//
// Supported only on Unix via mlock/munlock syscalls.
Mlock bool
path string
openFile func(string, int, os.FileMode) (*os.File, error)
file *os.File
@@ -194,7 +188,6 @@ func Open(path string, mode os.FileMode, options *Options) (*DB, error) {
db.MmapFlags = options.MmapFlags
db.NoFreelistSync = options.NoFreelistSync
db.FreelistType = options.FreelistType
db.Mlock = options.Mlock
// Set default values for later DB operations.
db.MaxBatchSize = DefaultMaxBatchSize
@@ -344,8 +337,7 @@ func (db *DB) mmap(minsz int) error {
}
// Ensure the size is at least the minimum size.
fileSize := int(info.Size())
var size = fileSize
var size = int(info.Size())
if size < minsz {
size = minsz
}
@@ -354,13 +346,6 @@ func (db *DB) mmap(minsz int) error {
return err
}
if db.Mlock {
// Unlock db memory
if err := db.munlock(fileSize); err != nil {
return err
}
}
// Dereference all mmap references before unmapping.
if db.rwtx != nil {
db.rwtx.root.dereference()
@@ -376,13 +361,6 @@ func (db *DB) mmap(minsz int) error {
return err
}
if db.Mlock {
// Don't allow swapping of data file
if err := db.mlock(fileSize); err != nil {
return err
}
}
// Save references to the meta pages.
db.meta0 = db.page(0).meta()
db.meta1 = db.page(1).meta()
@@ -444,36 +422,12 @@ func (db *DB) mmapSize(size int) (int, error) {
return int(sz), nil
}
func (db *DB) munlock(fileSize int) error {
if err := munlock(db, fileSize); err != nil {
return fmt.Errorf("munlock error: " + err.Error())
}
return nil
}
func (db *DB) mlock(fileSize int) error {
if err := mlock(db, fileSize); err != nil {
return fmt.Errorf("mlock error: " + err.Error())
}
return nil
}
func (db *DB) mrelock(fileSizeFrom, fileSizeTo int) error {
if err := db.munlock(fileSizeFrom); err != nil {
return err
}
if err := db.mlock(fileSizeTo); err != nil {
return err
}
return nil
}
// init creates a new database file and initializes its meta pages.
func (db *DB) init() error {
// Create two meta pages on a buffer.
buf := make([]byte, db.pageSize*4)
for i := 0; i < 2; i++ {
p := db.pageInBuffer(buf, pgid(i))
p := db.pageInBuffer(buf[:], pgid(i))
p.id = pgid(i)
p.flags = metaPageFlag
@@ -490,13 +444,13 @@ func (db *DB) init() error {
}
// Write an empty freelist at page 3.
p := db.pageInBuffer(buf, pgid(2))
p := db.pageInBuffer(buf[:], pgid(2))
p.id = pgid(2)
p.flags = freelistPageFlag
p.count = 0
// Write an empty leaf page at page 4.
p = db.pageInBuffer(buf, pgid(3))
p = db.pageInBuffer(buf[:], pgid(3))
p.id = pgid(3)
p.flags = leafPageFlag
p.count = 0
@@ -508,7 +462,6 @@ func (db *DB) init() error {
if err := fdatasync(db); err != nil {
return err
}
db.filesz = len(buf)
return nil
}
@@ -1020,12 +973,6 @@ func (db *DB) grow(sz int) error {
if err := db.file.Sync(); err != nil {
return fmt.Errorf("file sync error: %s", err)
}
if db.Mlock {
// unlock old file and lock new one
if err := db.mrelock(db.filesz, sz); err != nil {
return fmt.Errorf("mlock/munlock error: %s", err)
}
}
}
db.filesz = sz
@@ -1117,11 +1064,6 @@ type Options struct {
// OpenFile is used to open files. It defaults to os.OpenFile. This option
// is useful for writing hermetic tests.
OpenFile func(string, int, os.FileMode) (*os.File, error)
// Mlock locks database file in memory when set to true.
// It prevents potential page faults, however
// used memory can't be reclaimed. (UNIX only)
Mlock bool
}
// DefaultOptions represent the options used if nil options are passed into Open().

View File

@@ -4,7 +4,7 @@ import "sort"
// hashmapFreeCount returns count of free pages(hashmap version)
func (f *freelist) hashmapFreeCount() int {
// use the forwardMap to get the total count
// use the forwardmap to get the total count
count := 0
for _, size := range f.forwardMap {
count += int(size)
@@ -41,7 +41,7 @@ func (f *freelist) hashmapAllocate(txid txid, n int) pgid {
for pid := range bm {
// remove the initial
f.delSpan(pid, size)
f.delSpan(pid, uint64(size))
f.allocs[pid] = txid
@@ -51,7 +51,7 @@ func (f *freelist) hashmapAllocate(txid txid, n int) pgid {
f.addSpan(pid+pgid(n), remain)
for i := pgid(0); i < pgid(n); i++ {
delete(f.cache, pid+i)
delete(f.cache, pid+pgid(i))
}
return pid
}

View File

@@ -1,36 +0,0 @@
// +build !windows
package bbolt
import "golang.org/x/sys/unix"
// mlock locks memory of db file
func mlock(db *DB, fileSize int) error {
sizeToLock := fileSize
if sizeToLock > db.datasz {
// Can't lock more than mmaped slice
sizeToLock = db.datasz
}
if err := unix.Mlock(db.dataref[:sizeToLock]); err != nil {
return err
}
return nil
}
//munlock unlocks memory of db file
func munlock(db *DB, fileSize int) error {
if db.dataref == nil {
return nil
}
sizeToUnlock := fileSize
if sizeToUnlock > db.datasz {
// Can't unlock more than mmaped slice
sizeToUnlock = db.datasz
}
if err := unix.Munlock(db.dataref[:sizeToUnlock]); err != nil {
return err
}
return nil
}

View File

@@ -1,11 +0,0 @@
package bbolt
// mlock locks memory of db file
func mlock(_ *DB, _ int) error {
panic("mlock is supported only on UNIX systems")
}
//munlock unlocks memory of db file
func munlock(_ *DB, _ int) error {
panic("munlock is supported only on UNIX systems")
}

3
vendor/go.etcd.io/bbolt/tx.go generated vendored
View File

@@ -188,6 +188,7 @@ func (tx *Tx) Commit() error {
}
// If strict mode is enabled then perform a consistency check.
// Only the first consistency error is reported in the panic.
if tx.db.StrictMode {
ch := tx.Check()
var errs []string
@@ -392,7 +393,7 @@ func (tx *Tx) CopyFile(path string, mode os.FileMode) error {
return err
}
_, err = tx.WriteTo(f)
err = tx.Copy(f)
if err != nil {
_ = f.Close()
return err

10
vendor/modules.txt vendored
View File

@@ -12,10 +12,10 @@ github.com/asdine/storm/v3/q
# github.com/davecgh/go-spew v1.1.1
## explicit
github.com/davecgh/go-spew/spew
# github.com/gorilla/websocket v1.4.2
# github.com/gorilla/websocket v1.5.0
## explicit; go 1.12
github.com/gorilla/websocket
# github.com/jinzhu/copier v0.3.4
# github.com/jinzhu/copier v0.3.5
## explicit; go 1.13
github.com/jinzhu/copier
# github.com/logrusorgru/aurora v2.0.3+incompatible
@@ -24,14 +24,14 @@ github.com/logrusorgru/aurora
# github.com/pmezard/go-difflib v1.0.0
## explicit
github.com/pmezard/go-difflib/difflib
# github.com/rs/xid v1.3.0
# github.com/rs/xid v1.4.0
## explicit; go 1.12
github.com/rs/xid
# github.com/stretchr/testify v1.7.0
# github.com/stretchr/testify v1.7.1
## explicit; go 1.13
github.com/stretchr/testify/assert
github.com/stretchr/testify/require
# go.etcd.io/bbolt v1.3.6
# go.etcd.io/bbolt v1.3.5
## explicit; go 1.12
go.etcd.io/bbolt
# golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d