Compare commits

..

144 Commits

Author SHA1 Message Date
JB
3a15cc3add Merge pull request #69 from mochi-co/v1.2.0
V1.2.0
2022-04-10 20:23:43 +01:00
mochi
765f6e7c2e use NewServer instead of New 2022-04-10 20:19:06 +01:00
mochi
e3bdfc1f8e update readme for new events 2022-04-10 20:18:18 +01:00
JB
60cd972b7f Merge pull request #68 from mochi-co/fix-store-retained
Fix Store Retained Messages
2022-04-10 19:55:34 +01:00
mochi
ee081d0abe Merge branch 'v1.2.0' of https://github.com/mochi-co/mqtt into fix-store-retained 2022-04-10 19:54:37 +01:00
JB
d97b4bb81d Merge pull request #67 from mochi-co/release-client-buffers
Release client buffers
2022-04-10 19:53:46 +01:00
mochi
94aeacf0cb only check final outcome due to races 2022-04-10 19:46:12 +01:00
mochi
5cb8a081a1 accept any error for invalid protocol due to races 2022-04-10 19:43:59 +01:00
mochi
fc00112e47 Check for protocol violation errors 2022-04-10 19:42:37 +01:00
mochi
bbc22fae5b Add comments 2022-04-10 19:30:26 +01:00
mochi
82cb75913d Remove unused code 2022-04-10 19:28:22 +01:00
mochi
45c4a64b87 Abandon client state if the existing client specified a cleansession 2022-04-10 19:07:46 +01:00
mochi
bae2579497 Expose CleanSession value for checking 2022-04-10 19:07:15 +01:00
mochi
1bc01271cb Store retained message based on corrected r value 2022-04-10 18:47:45 +01:00
mochi
352a71f50c Expect correct r values for RetainMessage 2022-04-10 18:46:46 +01:00
mochi
6f9f62e38f Correctly return R value of retainMessage 2022-04-10 18:46:23 +01:00
mochi
0f67d9e8ff Update EstablishConnection tests to ensure buffers and pool are correctly released after use 2022-04-10 17:37:46 +01:00
mochi
f2dd5b63ae Export R/W buffer values so they can be assessed in tests without causing races 2022-04-10 17:37:25 +01:00
mochi
54e2d044a2 Track number of pool blocks in use 2022-04-10 17:36:46 +01:00
mochi
6298a87298 Use package errors instead of strings 2022-04-10 14:58:59 +01:00
mochi
b6fd25bba4 test clarbuffers 2022-04-10 01:46:18 +01:00
mochi
eef3592576 clear buffers after deferred stop 2022-04-10 01:46:04 +01:00
mochi
5d343c12e1 refactor clients for buffer releasing 2022-04-10 01:34:01 +01:00
mochi
70f52c8a3b Refactor establishconnection to prevent same-id disconnects 2022-04-10 01:33:35 +01:00
mochi
429b72265a refactor connSetup for clarity 2022-04-10 01:32:24 +01:00
mochi
f60d2dcfca Clarify error messages 2022-04-10 01:32:04 +01:00
mochi
6674cd64eb clarify error checking 2022-04-10 01:31:36 +01:00
mochi
f218cde69c Use defer to release buffers and decrease stats on any client closure 2022-04-09 20:53:28 +01:00
JB
9ea687eb94 Merge pull request #61 from mochi-co/server-options
Configurable Server Options
2022-04-09 20:29:23 +01:00
JB
949e4e2e91 Merge pull request #63 from mochi-co/add-drop-packet-error
Add ErrRejectPacket to OnProcessMessage
2022-04-09 20:29:03 +01:00
JB
515e0269de Merge pull request #62 from mochi-co/fix-inflight-key
Fix Inflight Persistence Key
2022-04-06 11:23:51 +01:00
mochi
ae6073c79c track logged error 2022-03-31 18:23:12 +01:00
mochi
b072a08f0b Update test to check for packet rejection 2022-03-31 18:04:41 +01:00
mochi
01d8a450d2 Ensure OnError is set before using it 2022-03-31 17:56:08 +01:00
mochi
da2fd41f79 Update OnProcessMessage documentation 2022-03-31 17:53:34 +01:00
mochi
56e8039093 Optionally drop a packet if the ErrRejectPacket error is returned from OnProcessMessage 2022-03-31 17:53:13 +01:00
mochi
7b9bc844c1 Add ErrRejectPacket error to abandon packet processing from OnMessageProcess 2022-03-31 17:52:40 +01:00
JB
8acb182820 Merge pull request #53 from stffabi/feature/onprocessmessage-event
Events: Add OnProcessMessage event
2022-03-31 17:41:22 +01:00
mochi
5726880095 fix inflight key reference 2022-03-31 17:36:03 +01:00
mochi
ee459e1b3d Fix code block formatting 2022-03-31 17:32:17 +01:00
mochi
6aec3a8bbf Update readme with server options 2022-03-31 17:21:42 +01:00
mochi
74699f0a87 Add example implementation 2022-03-31 17:21:35 +01:00
mochi
70def39ff9 add tests for new NewServer function 2022-03-31 17:00:00 +01:00
mochi
8e7098a32d remove deprecated log message 2022-03-31 16:59:49 +01:00
mochi
e4f02919fd Update code to use new NewServer function instead of deprecated New 2022-03-31 16:49:27 +01:00
mochi
99c96c844e Update example code to use new NewServer function instead of deprecated New 2022-03-31 16:49:02 +01:00
mochi
18629aea6d Use internal default values instead of relying on passed value 2022-03-31 16:48:42 +01:00
mochi
a0060429d1 Add Server Options
Adds a new struct of server options which can be used to override default properties. A new options-accepting NewServer function has been created to supersede the New method, which is now deprecated.
2022-03-31 16:48:29 +01:00
mochi
d946a9ae16 Update go mod to ensure bolt is using 1.3.5
Bolt 1.3.6 fails to build correctly and has been removed, so rollback to bolt 1.3.5. Also upgrade to Go 1.18
2022-03-31 16:39:28 +01:00
JB
51a2eb5f48 Merge pull request #57 from hybridgroup/v1.2.0-docker
docker: add initial simple Dockerfile
2022-03-30 09:37:53 +01:00
JB
0d4b0a89d8 Merge pull request #58 from soyoo/patch-1
typo
2022-03-30 09:30:11 +01:00
soyoo
0e7ccfe3fb typo 2022-03-25 16:55:57 +08:00
Ron Evans
5d7230630d docker: add initial simple Dockerfile 2022-03-23 10:37:20 +01:00
JB
6a3cbd6093 Merge pull request #51 from jmacd/jmacd/noracefix
Two no-functional-change cleanups combined
2022-03-22 16:19:24 +00:00
stffabi
7b4e79707b Events: Add OnProcessMessage event
This event gets called right after ACL checking but before any other
Fields of the packet get evaluated.
2022-03-21 10:11:29 +01:00
Joshua MacDonald
c6643592f6 Combines two fixes 2022-03-17 14:19:23 -07:00
mochi
5de12d0460 Merge branch 'master' of https://github.com/mochi-co/mqtt into v1.1.2 2022-03-17 12:53:59 +00:00
JB
0a7205e110 Update README.md 2022-03-17 12:50:01 +00:00
JB
8133dd8299 Update README.md 2022-03-17 12:48:37 +00:00
JB
fdbfff57dc Merge pull request #46 from stffabi/bugfix/acls-retain
Subscribe: Only send retained messages if ACLs has allowed subscription to the topic
2022-03-17 09:41:48 +00:00
stffabi
f5fc5e8c44 Subscribe: Only send retained messages if ACLs has allowed subscription to the topic 2022-03-17 09:25:29 +01:00
mochi
9f44712b80 Fix incorrect test
The previous publish inline test incorrectly approved retain packets without retain=true fixedheader values.
2022-03-16 18:16:48 +00:00
stffabi
1f86168d9d Publish: Set the retain flag in the fixedheader (#42)
* Publish: Set the retain flag in the fixedheader
2022-03-16 18:12:07 +00:00
mochi
ab25083ed2 Merge branch 'master' of https://github.com/mochi-co/mqtt into v1.1.2
# Conflicts:
#	server/internal/clients/clients.go
2022-03-16 18:09:00 +00:00
JB
9b0aa4d559 Update README.md 2022-03-15 20:04:04 +00:00
JB
03814944a9 Update README.md 2022-03-15 19:58:58 +00:00
JB
3286d5a484 Replace Travis with Github Actions (#41)
* Remove Travis CI

* Add Github Actions Workflow

* Update badges for build status, coverage, report card, doc reference

* use actions for all pull requests and pushes

* test all files for coverage

* Apply gofmt -s to simplify code

* Fix typos

* Cleanup comments

* Cleanup comments

Co-authored-by: mochi <mochimou@icloud.com>
2022-03-15 19:56:42 +00:00
mochi
7e970d3c7a Fix typo 2022-03-15 19:13:24 +00:00
mochi
d6a92cc5bd Add Keyed fields to events.Client for readability and go vet 2022-03-15 18:44:49 +00:00
mochi
325d44d478 Add missing method comments 2022-03-15 18:44:21 +00:00
Joshua MacDonald
0a5f6d3a9d Add an OnError handler; report the reason for disconnects. (#38) 2022-03-15 17:59:52 +00:00
Joshua MacDonald
17253ad8bd Wrap packet errors with cause information (#39) 2022-03-15 17:34:49 +00:00
Joshua MacDonald
9f1c387091 Move two WaitGroup.Add calls (#36) 2022-03-15 17:33:31 +00:00
JB
9c6f602630 Merge pull request #29 from jmacd/jmacd/payload_not_utf8
Support non-UTF8 payloads (per MQTT specification)
2022-02-27 08:36:40 +00:00
Joshua MacDonald
b0dcaabdde Support non-UTF8 payloads per MQTT specification 2022-02-26 22:53:51 -08:00
JB
460f0ef681 revert redis update 2022-02-24 21:19:46 +00:00
JB
6e16765f60 revert server version 2022-02-24 21:19:22 +00:00
JB
2b361df19e Merge pull request #27 from mochi-co/revert-26-master
Revert "added redis persistence mode"
2022-02-24 21:15:08 +00:00
JB
c8c0a5a094 Revert "added redis persistence mode" 2022-02-24 21:10:39 +00:00
JB
4a833dd081 Update server version 2022-02-24 21:07:54 +00:00
JB
81198d9845 Update README.md 2022-02-24 21:07:24 +00:00
JB
6c12d8a71a Merge pull request #26 from wind-c/master
added redis persistence mode
2022-02-24 21:05:35 +00:00
narwal
19b598b672 redis and trie 2022-02-23 16:00:32 +08:00
narwal
b6529f05d3 add redis persistence mode and example 2022-02-22 18:57:31 +08:00
mochi
7f76445cc8 update server version 2022-01-30 10:39:49 +00:00
JB
b1c01792cd Merge pull request #24 from mochi-co/feature/optimise-struct-fields
Optimise Struct Fields + Fixes
2022-01-30 10:38:28 +00:00
mochi
eda03d4338 optimise Server struct 2022-01-30 10:30:34 +00:00
mochi
18070f1f57 pass byte pool by address 2022-01-30 10:30:19 +00:00
mochi
7f10c28a37 remove println 2022-01-30 10:30:01 +00:00
mochi
122531bb27 Pass inflight by address to avoid lock copying 2022-01-28 21:07:10 +00:00
mochi
e6dbcae428 Correct function signature 2022-01-28 21:06:57 +00:00
mochi
98875de568 Update test to match new FixedHeader struct 2022-01-28 21:06:43 +00:00
mochi
c9fd9451af Prevent locks from being copied 2022-01-28 21:06:24 +00:00
mochi
6550b8d680 8bit align struct fields 2022-01-28 21:05:50 +00:00
mochi
a60c96c889 Update comment for clarity 2022-01-28 21:04:15 +00:00
mochi
86e0a5827e Update version to 1.1.0 2022-01-26 20:49:53 +00:00
mochi
06c399b606 indicate ARM32 compatibility 2022-01-26 20:49:42 +00:00
JB
ed117f67a1 Merge pull request #22 from mochi-co/feature/32bit-compatibility
ARM32 Compatibility
2022-01-26 20:36:13 +00:00
JB
880a3299e1 Merge pull request #19 from rkennedy/bugfix/32-bit-atomic-alignment
Improve 32-bit compatibility
2022-01-26 08:02:57 +00:00
Rob Kennedy
1c408d05be Fix encodeLength for 32-bit platforms
When `int` is 32 bits, `MaxInt64` doesn't fit. It's apparent that
`encodeLength` expects to handle 64-bit inputs, so let's make that
explicit, which allows the test to run on all platforms.
2022-01-25 00:22:26 -06:00
Rob Kennedy
fce495f83e Avoid race condition when closing listeners
"Atomic load" followed by "atomic store" is not itself an atomic
operation. This commit replaces that sequence with CompareAndSwap
instead.
2022-01-25 00:22:26 -06:00
Rob Kennedy
471ca00a64 Make atomics work on 32-bit systems
On 32-bit systems, `atomic` requires its 64-bit arguments to have 64-bit
alignment, but the compiler doesn't help ensure that's the case. In this
commit, fields that don't need to hold large numbers have been converted
to 32-bit types, which are always aligned correctly on all platforms.
For fields that may hold large numeric values, padding has been added to
get the necessary alignment, and tests have been added to avoid
regressions.
2022-01-25 00:22:26 -06:00
mochi
a2c0749640 Update server version to 1.0.5 2022-01-24 18:46:34 +00:00
JB
37293aeecf Merge pull request #18 from mochi-co/feature/connect-disconnect-hooks
OnConnect and OnDisconnect Event Hooks
2022-01-24 18:44:39 +00:00
mochi
7a2d4db6a4 Update for OnConnect and OnDisconnect hooks 2022-01-24 18:42:09 +00:00
mochi
03d2a8bc82 Add tests for OnConnect, OnDisconnect 2022-01-24 18:29:18 +00:00
mochi
4b51e5c7d1 Add OnConnect and OnDisconnect hooks to example 2022-01-24 17:42:33 +00:00
mochi
d15ad682bf Call OnDisconnect Event if applicable 2022-01-24 17:42:19 +00:00
mochi
130ffcbb53 Add OnDisconnect Event Hook 2022-01-24 17:42:04 +00:00
mochi
33cf2f991b Add testbolt file to ignore list 2022-01-24 17:41:46 +00:00
mochi
a360ea6a6c Call OnConnect Event if applicable 2022-01-24 17:37:11 +00:00
mochi
ae3aa0d3fa Add OnConnect event hook 2022-01-24 17:36:50 +00:00
mochi
811ae0e1be Prevent locks being copied by passing non-pointer to FromClient 2022-01-24 17:36:14 +00:00
JB
51d6825430 Merge pull request #15 from ClarkQAQ/master
Fixed some bugs, wish the project better and better
2022-01-17 10:08:20 +00:00
clark
514288c53e update tcp.go maybe this will be better 2022-01-16 20:06:49 +08:00
clark
957fc0a049 fix local variable black hole 2022-01-16 18:23:45 +08:00
clark
03f94f948a update mock.go plase use range 2022-01-16 18:22:37 +08:00
clark
1bc752a2b8 fix [ST1005] strings should not be capitalized 2022-01-16 18:21:33 +08:00
clark
b9db59ba12 update websocket.go fix check origin 2022-01-16 18:20:06 +08:00
JB
c0ef58c363 Update README.md 2022-01-14 17:48:21 +00:00
JB
994adea3b4 Merge pull request #14 from mochi-co/feature/allow-clients-value
Add AllowClients Field to packets
2022-01-14 17:38:29 +00:00
mochi
fc61cc9be5 Add example for AllowClients field 2022-01-14 17:04:55 +00:00
mochi
22d7338878 Add test for AllowClients field 2022-01-14 17:04:39 +00:00
mochi
3f28515706 Remove unnecessary type declarations 2022-01-14 17:04:21 +00:00
mochi
7d73ce9caf Add setupServerClients to inherit existing server instance
previously new clients generated a new server object, so system stats were not shared. This change ensures all test clients use the same server
2022-01-14 17:04:01 +00:00
mochi
0758bc961c Add AllowClients check in publishToSubscribers
If AllowClients has been set on a packet, ensure only clients in the slice are sent the message
2022-01-14 17:02:31 +00:00
mochi
8472b9ae8a use .systemInfo instead of .system for clarity 2022-01-14 17:01:42 +00:00
mochi
530a018e80 use .systemInfo instead of .system for clarity 2022-01-14 17:01:31 +00:00
mochi
0b594afb4e Add AllowClients field to packets
AllowClients field can be specified during onMessage event to selectively deliver messages
2022-01-14 16:59:17 +00:00
mochi
9d0ea957bb Increment server version 2022-01-14 16:58:48 +00:00
mochi
8067785ac4 Add tests for InSliceString 2022-01-14 16:58:33 +00:00
mochi
6ffc8a8388 Add InSliceString function
Check if a slice of strings contains a string (until slices package available)
2022-01-14 16:58:21 +00:00
mochi
fb136483d0 Revert server version 2022-01-10 23:50:40 +00:00
mochi
b209cd95f1 increment server version 2022-01-10 23:48:33 +00:00
mochi
3a7e58ec01 Remove unnecessary fmt import 2022-01-10 23:47:33 +00:00
mochi
a674632cce Increment server version 2022-01-10 23:41:46 +00:00
JB
09ddc412c7 Merge pull request #12 from jphastings/remove-erroneous-print 2022-01-10 23:38:33 +00:00
JP Hastings-Spital
6fbd8a5eb2 Remove unnecessary println 2022-01-10 23:36:33 +00:00
JB
d4ae73a97c fix indentation in code blocks
convert tabs to spaces
2022-01-05 21:43:47 +00:00
JB
3ff853a990 Update README.md 2022-01-05 21:41:45 +00:00
mochi
4302eed84f Update vendor 2022-01-05 21:28:00 +00:00
mochi
a1fee6ff68 Update go mod to 1.17 2022-01-05 21:27:52 +00:00
mochi
7fbc0b0187 fix code indents 2022-01-05 21:26:11 +00:00
mochi
8bbca347c4 Update go to 1.17 2022-01-05 21:21:35 +00:00
322 changed files with 28348 additions and 99156 deletions

43
.github/workflows/build.yml vendored Normal file
View File

@@ -0,0 +1,43 @@
name: build
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.17
- name: Vet
run: go vet ./...
- name: Test
run: go test -v -race ./...
coverage:
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.17.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov

3
.gitignore vendored
View File

@@ -1,2 +1,3 @@
cmd/mqtt
.DS_Store
.DS_Store
server/persistence/bolt/testbolt.db

View File

@@ -1,18 +0,0 @@
dist: xenial
language: go
env:
- GO111MODULE=on
go:
- 1.13.x
git:
depth: 1
script:
- go test -v -race ./... -coverprofile=coverage.txt -covermode=atomic
after_success:
- bash <(curl -s https://codecov.io/bash)

31
Dockerfile Normal file
View File

@@ -0,0 +1,31 @@
FROM golang:1.18.0-alpine3.15 AS builder
RUN apk update
RUN apk add git
WORKDIR /app
COPY go.mod ./
COPY go.sum ./
RUN go mod download
COPY . ./
RUN go build -o /app/mochi ./cmd
FROM alpine
WORKDIR /
COPY --from=builder /app/mochi .
# tcp
EXPOSE 1883
# websockets
EXPOSE 1882
# dashboard
EXPOSE 8080
ENTRYPOINT [ "/mochi" ]

168
README.md
View File

@@ -1,10 +1,11 @@
<p align="center">
[![Build Status](https://travis-ci.com/mochi-co/mqtt.svg?token=59nqixhtefy2iQRwsPcu&branch=master)](https://travis-ci.com/mochi-co/mqtt)
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
[![codecov](https://codecov.io/gh/mochi-co/mqtt/branch/master/graph/badge.svg?token=6vBUgYVaVB)](https://codecov.io/gh/mochi-co/mqtt)
[![GoDoc](https://godoc.org/github.com/mochi-co/mqtt?status.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt)
</p>
@@ -13,6 +14,10 @@
Mochi MQTT is an embeddable high-performance MQTT broker server written in Go, and compliant with the MQTT v3.0 and v3.1.1 specification for the development of IoT and smarthome projects. The server can be used either as a standalone binary or embedded as a library in your own projects. Mochi MQTT message throughput is comparable with everyone's favourites such as Mosquitto, Mosca, and VerneMQ.
> #### 📦 💬 See Github Discussions for discussions about releases
> Ongoing discussion about current and future releases can be found at https://github.com/mochi-co/mqtt/discussions
#### What is MQTT?
MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks. [Learn more](https://mqtt.org/faq)
@@ -23,13 +28,14 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
- Ring Buffer packet codec.
- TCP, Websocket, (including SSL/TLS) and Dashboard listeners.
- Interfaces for Client Authentication and Topic access control.
- Bolt-backed persistence and storage interfaces.
- Bolt persistence and storage interfaces (see examples folder).
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
- Basic Event Hooks (`OnMessage`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
- ARM32 Compatible.
#### Roadmap
- Event Hooks (eg. provide handler functions for `onMessage`).
- Docker Image
- MQTT v5 compatibility
- Please open an issue to request new features or event hooks.
- MQTT v5 compatibility?
#### Using the Broker
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the `main.go` entrypoint in the `cmd` folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners. A docker image is coming soon.
@@ -43,28 +49,28 @@ go build -o mqtt && ./mqtt
``` go
import (
mqtt "github.com/mochi-co/mqtt/server"
mqtt "github.com/mochi-co/mqtt/server"
)
func main() {
// Create the new MQTT Server.
server := mqtt.New()
server := mqtt.NewServer(nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", ":1883")
// Add the listener to the server with default options (nil).
err := server.AddListener(tcp, nil)
if err != nil {
log.Fatal(err)
}
if err != nil {
log.Fatal(err)
}
// Start the broker. Serve() is blocking - see examples folder
// for usage ideas.
// Start the broker. Serve() is blocking - see examples folder
// for usage ideas.
err = server.Serve()
if err != nil {
log.Fatal(err)
}
if err != nil {
log.Fatal(err)
}
}
```
@@ -83,9 +89,9 @@ When a listener is added to the server using `server.AddListener`, a `*listeners
Authentication and ACL may be configured on a per-listener basis by providing an Auth Controller to the listener configuration. Custom Auth Controllers should satisfy the `auth.Controller` interface found in `listeners/auth`. Two default controllers are provided, `auth.Allow`, which allows all traffic, and `auth.Disallow`, which denies all traffic.
```go
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
})
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
})
```
> If no auth controller is provided in the listener configuration, the server will default to _Disallowing_ all traffic to prevent unintentional security issues.
@@ -93,49 +99,107 @@ Authentication and ACL may be configured on a per-listener basis by providing an
##### SSL
SSL may be configured on both the TCP and Websocket listeners by providing a public-private PEM key pair to the listener configuration as `[]byte` slices.
```go
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: publicCertificate,
PrivateKey: privateKey,
},
})
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
TLS: &listeners.TLS{
Certificate: publicCertificate,
PrivateKey: privateKey,
},
})
```
> Note the mandatory inclusion of the Auth Controller!
#### Event Hooks
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
##### OnMessage
`server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns.
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
##### OnConnect
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
```go
import "github.com/mochi-co/mqtt/server/events"
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
if string(pk.Payload) == "hello" {
pkx = pk
pkx.Payload = []byte("hello world")
return pkx, nil
}
return pk, nil
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
```
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
##### OnDisconnect
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
```go
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
}
```
##### OnMessage
`server.Events.OnMessage` is called when a Publish packet (message) is received. The method receives the published message and information about the client who published it.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
##### OnProcessMessage
`server.Events.OnProcessMessage` is called before a publish packet (message) is processed. Specifically, the method callback is triggered after topic and ACL validation has occurred, but before the headers and payload are processed. You can use this if you want to programmatically change the data of the packet, such as setting it to retain, or altering the QoS flag.
If an error is returned, the packet will not be modified. and the existing packet will be used. If this is an unwanted outcome, the `mqtt.ErrRejectPacket` error can be returned from the callback, and the packet will be dropped/ignored, any further processing is abandoned.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
```go
import "github.com/mochi-co/mqtt/server/events"
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
if string(pk.Payload) == "hello" {
pkx = pk
pkx.Payload = []byte("hello world")
return pkx, nil
}
return pk, nil
}
```
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
##### OnError
`server.Events.OnError` is called when an error is encountered on the server, particularly within the use of a client connection status.
##### OnStorage
`server.Events.OnStorage` is like `onError`, but receives the output of persistent storage methods.
#### Server Options
A few options can be passed to the `mqtt.NewServer(opts *Options)` function in order to override the default broker configuration. Currently these options are:
- BufferSize (default 1024 * 256 bytes) - The default value is sufficient for most messaging sizes, but if you are sending many kilobytes of data (such as images), you should increase this to a value of (n*s) where is the typical size of your message and n is the number of messages you may have backlogged for a client at any given time.
- BufferBlockSize (default 1024 * 8) - The minimum size in which R/W data will be allocated. If you are expecting only tiny or large payloads, you can alter this accordingly.
Any options which is not set or is `0` will use default values.
```go
opts := &mqtt.Options{
BufferSize: 512 * 1024,
BufferBlockSize: 16 * 1024,
}
s := mqtt.NewServer(opts)
```
> See `examples/tcp/main.go` for an example implementation.
#### Direct Publishing
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.
```go
// func (s *Server) Publish(topic string, payload []byte, retain bool) error
err := s.Publish("a/b/c", []byte("hello"), false)
if err != nil {
log.Fatal(err)
}
// func (s *Server) Publish(topic string, payload []byte, retain bool) error
err := s.Publish("a/b/c", []byte("hello"), false)
if err != nil {
log.Fatal(err)
}
```
A working example can be found in the `examples/events` folder.
@@ -143,11 +207,11 @@ A working example can be found in the `examples/events` folder.
#### Data Persistence
Mochi MQTT provides a `persistence.Store` interface for developing and attaching persistent stores to the broker. The default persistence mechanism packaged with the broker is backed by [Bolt](https://github.com/etcd-io/bbolt) and can be enabled by assigning a `*bolt.Store` to the server.
```go
// import "github.com/mochi-co/mqtt/server/persistence/bolt"
err = server.AddStore(bolt.New("mochi.db", nil))
if err != nil {
log.Fatal(err)
}
// import "github.com/mochi-co/mqtt/server/persistence/bolt"
err = server.AddStore(bolt.New("mochi.db", nil))
if err != nil {
log.Fatal(err)
}
```
> Persistence is on-demand (not flushed) and will potentially reduce throughput when compared to the standard in-memory store. Only use it if you need to maintain state through restarts.
@@ -155,7 +219,7 @@ Mochi MQTT provides a `persistence.Store` interface for developing and attaching
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the test with `python3 client_test.py` from the _interoperability_ folder.
#### Performance (messages/second)
#### Performance at v1.0.0
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a 13-inch, Early 2015 Macbook Pro (2.7 GHz Intel Core i5). Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better. SEND = Publish throughput, RECV = Subscribe throughput.
> As usual, any performance benchmarks should be taken with a pinch of salt, but are shown to demonstrate typical throughput compared to the other leading MQTT brokers.

View File

@@ -33,7 +33,7 @@ func main() {
fmt.Println(aurora.Cyan("Websocket"), *wsAddr)
fmt.Println(aurora.Cyan("$SYS Dashboard"), *infoAddr)
server := mqtt.New()
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", *tcpAddr)
err := server.AddListener(tcp, nil)
if err != nil {

View File

@@ -24,7 +24,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
server := mqtt.NewServer(nil)
stats := listeners.NewHTTPStats("stats", ":8080")
err := server.AddListener(stats, nil)

View File

@@ -27,7 +27,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),
@@ -44,6 +44,16 @@ func main() {
}
}()
// Add OnConnect Event Hook
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
// Add OnDisconnect Event Hook
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
}
// Add OnMessage Event Hook
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
pkx = pk
@@ -54,6 +64,12 @@ func main() {
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
}
// Example of using AllowClients to selectively deliver/drop messages.
// Only a client with the id of `allowed-client` will received messages on the topic.
if pkx.TopicName == "a/b/restricted" {
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
}
return pkx, nil
}

View File

@@ -46,7 +46,7 @@ func main() {
// Auth is an example auth provider for the server.
type Auth struct{}
// Auth returns true if a username and password are acceptable.
// Authenticate returns true if a username and password are acceptable.
// Auth always returns true.
func (a *Auth) Authenticate(user, password []byte) bool {
return true
@@ -55,8 +55,5 @@ func (a *Auth) Authenticate(user, password []byte) bool {
// ACL returns true if a user has access permissions to read or write on a topic.
// ACL is used to deny access to a specific topic to satisfy Test.test_subscribe_failure.
func (a *Auth) ACL(user []byte, topic string, write bool) bool {
if topic == "test/nosubscribe" {
return false
}
return true
return topic != "test/nosubscribe"
}

View File

@@ -28,7 +28,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("Persistence"))
server := mqtt.New()
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),

View File

@@ -25,7 +25,13 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
// An example of configuring various server options...
options := &mqtt.Options{
BufferSize: 0, // Use default values
BufferBlockSize: 0, // Use default values
}
server := mqtt.NewServer(options)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),

View File

@@ -57,7 +57,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
server := mqtt.New()
server := mqtt.NewServer(nil)
tcp := listeners.NewTCP("t1", ":1883")
err := server.AddListener(tcp, &listeners.Config{
Auth: new(auth.Allow),

View File

@@ -24,7 +24,7 @@ func main() {
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
server := mqtt.New()
server := mqtt.NewServer(nil)
ws := listeners.NewWebsocket("ws1", ":1882")
err := server.AddListener(ws, nil)
if err != nil {

24
go.mod
View File

@@ -1,15 +1,21 @@
module github.com/mochi-co/mqtt
go 1.13
go 1.18
require (
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.1.0
github.com/gorilla/websocket v1.4.1
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23
github.com/rs/xid v1.2.1
github.com/stretchr/testify v1.4.0
go.etcd.io/bbolt v1.3.3
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
github.com/asdine/storm/v3 v3.2.1
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
github.com/logrusorgru/aurora v2.0.3+incompatible
github.com/rs/xid v1.4.0
github.com/stretchr/testify v1.7.1
go.etcd.io/bbolt v1.3.5
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

43
go.sum
View File

@@ -4,9 +4,8 @@ github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwv
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM=
github.com/asdine/storm v2.1.2+incompatible h1:dczuIkyqwY2LrtXPz8ixMrU/OFgZp71kbKTHGrXYt/Q=
github.com/asdine/storm v2.1.2+incompatible/go.mod h1:RarYDc9hq1UPLImuiXK3BIWPJLdIygvV3PsInK0FbVQ=
github.com/asdine/storm/v3 v3.1.0 h1:yrpSNS+E7ef5Y5KjyZDeyW72Dl17lYG7oZ7eUoWvo5s=
github.com/asdine/storm/v3 v3.1.0/go.mod h1:letAoLCXz4UfodwNgMNILMb2oRH+su337ZfHnkRzqDA=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac=
github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -15,45 +14,45 @@ github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a h1:zPPuIq2jAWWPTrGt70eK/BSch+gFAGrNzecsoENgu2o=
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a/go.mod h1:yL958EeXv8Ylng6IfnvG4oflryUi3vgA3xPs9hmII1s=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23 h1:Wp7NjqGKGN9te9N/rvXYRhlVcrulGdxnz8zadXWs7fc=
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4=
github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczGlG91VSDkswnjF5A8=
github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk=
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0 h1:QPlSTtPE2k6PZPasQUbzuK3p9JbS+vMXYVto8g/yrsg=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191105142833-ac3223d80179 h1:IqVhUQp5B9ARnZUcfqXy6zP+A+YuPpP7IFo8gFeCOzU=
golang.org/x/sys v0.0.0-20191105142833-ac3223d80179/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d h1:L/IKR6COd7ubZrs2oTnTi73IhgqJ71c9s80WsQnh0Es=
golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1,29 +1,49 @@
package events
import (
"github.com/mochi-co/mqtt/server/internal/clients"
"github.com/mochi-co/mqtt/server/internal/packets"
)
// Events provides callback handlers for different event hooks.
type Events struct {
OnMessage // published message receieved.
OnProcessMessage // published message receieved before evaluation.
OnMessage // published message receieved.
OnError // server error.
OnConnect // client connected.
OnDisconnect // client disconnected.
}
// Packets is an alias for packets.Packet.
type Packet packets.Packet
// Client contains limited information about a connected client.
type Client struct {
ID string
Remote string
Listener string
}
// FromClient returns an event client from a client.
func FromClient(cl clients.Client) Client {
return Client{
ID: cl.ID,
Listener: cl.Listener,
}
// Clientlike is an interface for Clients and client-like objects that
// are able to describe their client/listener IDs and remote address.
type Clientlike interface {
Info() Client
}
// OnProcessMessage is called when a publish message is received, allowing modification
// of the packet data after ACL checking has occurred but before any data is evaluated
// for processing - e.g. for changing the Retain flag. Note, this hook is ONLY called
// by connected client publishers, it is not triggered when using the direct
// s.Publish method. The function receives the sent message and the
// data of the client who published it, and allows the packet to be modified
// before it is dispatched to subscribers. If no modification is required, return
// the original packet data. If an error occurs, the original packet will
// be dispatched as if the event hook had not been triggered.
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
// processing of the packet.
type OnProcessMessage func(Client, Packet) (Packet, error)
// OnMessage function is called when a publish message is received. Note,
// this hook is ONLY called by connected client publishers, it is not triggered when
// using the direct s.Publish method. The function receives the sent message and the
@@ -34,3 +54,15 @@ func FromClient(cl clients.Client) Client {
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
type OnMessage func(Client, Packet) (Packet, error)
// OnConnect is called when a client successfully connects to the broker.
type OnConnect func(Client, Packet)
// OnDisconnect is called when a client disconnects to the broker. An error value
// is passed to the function if the client disconnected abnormally, otherwise it
// will be nil on a normal disconnect.
type OnDisconnect func(Client, error)
// OnError is called when errors that will not be passed to
// OnDisconnect are handled by the server.
type OnError func(Client, error)

View File

@@ -8,33 +8,39 @@ import (
)
var (
DefaultBufferSize int = 1024 * 256 // the default size of the buffer in bytes.
DefaultBlockSize int = 1024 * 8 // the default size per R/W block in bytes.
// DefaultBufferSize is the default size of the buffer in bytes.
DefaultBufferSize int = 1024 * 256
ErrOutOfRange = errors.New("Indexes out of range")
// DefaultBlockSize is the default size per R/W block in bytes.
DefaultBlockSize int = 1024 * 8
// ErrOutOfRange indicates that the index was out of range.
ErrOutOfRange = errors.New("Indexes out of range")
// ErrInsufficientBytes indicates that there were not enough bytes to return.
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
)
// buffer contains core values and methods to be included in a reader or writer.
// Buffer is a circular buffer for reading and writing messages.
type Buffer struct {
Mu sync.RWMutex // the buffer needs it's own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
buf []byte // the bytes buffer.
tmp []byte // a temporary buffer.
Mu sync.RWMutex // the buffer needs its own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
head int64 // the current position in the sequence - a forever increasing index.
tail int64 // the committed position in the sequence - a forever increasing index.
rcond *sync.Cond // the sync condition for the buffer reader.
wcond *sync.Cond // the sync condition for the buffer writer.
done int64 // indicates that the buffer is closed.
State int64 // indicates whether the buffer is reading from (1) or writing to (2).
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
done uint32 // indicates that the buffer is closed.
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
}
// NewBuffer returns a new instance of buffer. You should call NewReader or
// NewWriter instead of this function.
func NewBuffer(size, block int) Buffer {
func NewBuffer(size, block int) *Buffer {
if size == 0 {
size = DefaultBufferSize
}
@@ -47,7 +53,7 @@ func NewBuffer(size, block int) Buffer {
size = 2 * block
}
return Buffer{
return &Buffer{
size: size,
mask: size - 1,
block: block,
@@ -59,14 +65,14 @@ func NewBuffer(size, block int) Buffer {
// NewBufferFromSlice returns a new instance of buffer using a
// pre-existing byte slice.
func NewBufferFromSlice(block int, buf []byte) Buffer {
func NewBufferFromSlice(block int, buf []byte) *Buffer {
l := len(buf)
if block == 0 {
block = DefaultBlockSize
}
b := Buffer{
b := &Buffer{
size: l,
mask: l - 1,
block: block,
@@ -78,7 +84,7 @@ func NewBufferFromSlice(block int, buf []byte) Buffer {
return b
}
// Get will return the tail and head positions of the buffer.
// GetPos will return the tail and head positions of the buffer.
// This method is for use with testing.
func (b *Buffer) GetPos() (int64, int64) {
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
@@ -127,7 +133,7 @@ func (b *Buffer) awaitEmpty(n int) error {
// then wait until tail has moved.
b.rcond.L.Lock()
for !b.checkEmpty(n) {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
b.rcond.L.Unlock()
return io.EOF
}
@@ -146,7 +152,7 @@ func (b *Buffer) awaitFilled(n int) error {
// the forever-incrementing tail and head integers.
b.wcond.L.Lock()
for !b.checkFilled(n) {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
b.wcond.L.Unlock()
return io.EOF
}
@@ -196,7 +202,7 @@ func (b *Buffer) CapDelta() int {
// Stop signals the buffer to stop processing.
func (b *Buffer) Stop() {
atomic.StoreInt64(&b.done, 1)
atomic.StoreUint32(&b.done, 1)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()

View File

@@ -5,6 +5,7 @@ import (
"sync/atomic"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/require"
)
@@ -50,6 +51,18 @@ func TestNewBufferFromSlice0Size(t *testing.T) {
require.Equal(t, 256, cap(buf.buf))
}
func TestAtomicAlignment(t *testing.T) {
var b Buffer
offset := unsafe.Offsetof(b.head)
require.Equalf(t, uintptr(0), offset%8,
"head requires 64-bit alignment for atomic: offset %d", offset)
offset = unsafe.Offsetof(b.tail)
require.Equalf(t, uintptr(0), offset%8,
"tail requires 64-bit alignment for atomic: offset %d", offset)
}
func TestGetPos(t *testing.T) {
buf := NewBuffer(16, 4)
tail, head := buf.GetPos()
@@ -140,7 +153,7 @@ func TestAwaitFilledEnded(t *testing.T) {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
@@ -190,7 +203,7 @@ func TestAwaitEmptyEnded(t *testing.T) {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
@@ -280,7 +293,7 @@ func TestCommitTailEnded(t *testing.T) {
o <- buf.CommitTail(5)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
@@ -300,5 +313,5 @@ func TestCapDelta(t *testing.T) {
func TestStop(t *testing.T) {
buf := NewBuffer(16, 4)
buf.Stop()
require.Equal(t, int64(1), buf.done)
require.Equal(t, uint32(1), buf.done)
}

View File

@@ -2,17 +2,23 @@ package circ
import (
"sync"
"sync/atomic"
)
// BytesPool is a pool of []byte.
type BytesPool struct {
pool sync.Pool
pool *sync.Pool
used int64
}
// NewBytesPool returns a sync.pool of []byte.
func NewBytesPool(n int) BytesPool {
return BytesPool{
pool: sync.Pool{
func NewBytesPool(n int) *BytesPool {
if n == 0 {
n = DefaultBufferSize
}
return &BytesPool{
pool: &sync.Pool{
New: func() interface{} {
return make([]byte, n)
},
@@ -22,6 +28,7 @@ func NewBytesPool(n int) BytesPool {
// Get returns a pooled bytes.Buffer.
func (b *BytesPool) Get() []byte {
atomic.AddInt64(&b.used, 1)
return b.pool.Get().([]byte)
}
@@ -31,4 +38,10 @@ func (b *BytesPool) Put(x []byte) {
x[i] = 0
}
b.pool.Put(x)
atomic.AddInt64(&b.used, -1)
}
// InUse returns the number of pool blocks in use.
func (b *BytesPool) InUse() int64 {
return atomic.LoadInt64(&b.used)
}

View File

@@ -22,6 +22,7 @@ func TestNewBytesPoolGet(t *testing.T) {
buf := bpool.Get()
require.Equal(t, make([]byte, 256), buf)
require.Equal(t, int64(1), bpool.InUse())
}
func BenchmarkBytesPoolGet(b *testing.B) {
@@ -34,7 +35,9 @@ func BenchmarkBytesPoolGet(b *testing.B) {
func TestNewBytesPoolPut(t *testing.T) {
bpool := NewBytesPool(256)
buf := bpool.Get()
require.Equal(t, int64(1), bpool.InUse())
bpool.Put(buf)
require.Equal(t, int64(0), bpool.InUse())
}
func BenchmarkBytesPoolPut(b *testing.B) {

View File

@@ -7,7 +7,7 @@ import (
// Reader is a circular buffer for reading data from an io.Reader.
type Reader struct {
Buffer
*Buffer
}
// NewReader returns a new Circular Reader.
@@ -19,7 +19,7 @@ func NewReader(size, block int) *Reader {
}
}
// NewReaderFromSlice returns a new Circular Reader using a pre-exising
// NewReaderFromSlice returns a new Circular Reader using a pre-existing
// byte slice.
func NewReaderFromSlice(block int, p []byte) *Reader {
b := NewBufferFromSlice(block, p)
@@ -32,10 +32,10 @@ func NewReaderFromSlice(block int, p []byte) *Reader {
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
atomic.StoreInt64(&b.State, 1)
defer atomic.StoreInt64(&b.State, 0)
atomic.StoreUint32(&b.State, 1)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
return total, nil
}
@@ -60,7 +60,7 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
n, err := r.Read(b.buf[start:end])
total += int64(n) // incr total bytes read.
if err != nil {
return total, nil
return total, err
}
// Move the head forward however many bytes were read.

View File

@@ -2,6 +2,8 @@ package circ
import (
"bytes"
"errors"
"io"
"sync/atomic"
"testing"
"time"
@@ -34,16 +36,18 @@ func TestReadFrom(t *testing.T) {
br := bytes.NewReader(b4)
_, err := buf.ReadFrom(br)
require.NoError(t, err)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buf[:4])
require.Equal(t, int64(4), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, int64(8), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.True(t, errors.Is(err, io.EOF))
require.Equal(t, int64(12), buf.head)
}
@@ -60,7 +64,7 @@ func TestReadFromWrap(t *testing.T) {
}()
time.Sleep(time.Millisecond * 100)
go func() {
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
@@ -116,7 +120,7 @@ func TestReadEnded(t *testing.T) {
o <- err
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()

View File

@@ -1,14 +1,14 @@
package circ
import (
"fmt"
"io"
"log"
"sync/atomic"
)
// Writer is a circular buffer for writing data to an io.Writer.
type Writer struct {
Buffer
*Buffer
}
// NewWriter returns a pointer to a new Circular Writer.
@@ -20,7 +20,7 @@ func NewWriter(size, block int) *Writer {
}
}
// NewWriterFromSlice returns a new Circular Writer using a pre-exising
// NewWriterFromSlice returns a new Circular Writer using a pre-existing
// byte slice.
func NewWriterFromSlice(block int, p []byte) *Writer {
b := NewBufferFromSlice(block, p)
@@ -31,11 +31,11 @@ func NewWriterFromSlice(block int, p []byte) *Writer {
}
// WriteTo writes the contents of the buffer to an io.Writer.
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
atomic.StoreInt64(&b.State, 2)
defer atomic.StoreInt64(&b.State, 0)
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
atomic.StoreUint32(&b.State, 2)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 {
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
return total, io.EOF
}
@@ -59,14 +59,12 @@ func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
p = append(p, b.buf[rTail:rHead]...)
}
//fmt.Println("writing", p)
n, err = w.Write(p)
total += n
total += int64(n)
if err != nil {
fmt.Println("writing err", err)
log.Println("error writing to buffer io.Writer;", err)
return
}
//fmt.Println("written", n)
// Move the tail forward the bytes written and broadcast change.
atomic.StoreInt64(&b.tail, tail+int64(n))

View File

@@ -52,14 +52,14 @@ func TestWriteTo(t *testing.T) {
var b bytes.Buffer
w := bufio.NewWriter(&b)
nc := make(chan int)
nc := make(chan int64)
go func() {
n, _ := buf.WriteTo(w)
nc <- n
}()
time.Sleep(time.Millisecond * 100)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()

View File

@@ -12,6 +12,7 @@ import (
"github.com/rs/xid"
"github.com/mochi-co/mqtt/server/events"
"github.com/mochi-co/mqtt/server/internal/circ"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
@@ -23,7 +24,9 @@ var (
// defaultKeepalive is the default connection keepalive value in seconds.
defaultKeepalive uint16 = 10
ErrConnectionClosed = errors.New("Connection not open")
// ErrConnectionClosed is returned when operating on a closed
// connection and/or when no error cause has been given.
ErrConnectionClosed = errors.New("connection not open")
)
// Clients contains a map of the clients known by the broker.
@@ -74,7 +77,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
clients := make([]*Client, 0, cl.Len())
cl.RLock()
for _, v := range cl.internal {
if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 {
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
clients = append(clients, v)
}
}
@@ -84,42 +87,43 @@ func (cl *Clients) GetByListener(id string) []*Client {
// Client contains information about a client known by the broker.
type Client struct {
sync.RWMutex
conn net.Conn // the net.Conn used to establish the connection.
r *circ.Reader // a reader for reading incoming bytes.
w *circ.Writer // a writer for writing outgoing bytes.
ID string // the client id.
AC auth.Controller // an auth controller inherited from the listener.
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
Listener string // the id of the listener the client is connected to.
Inflight Inflight // a map of in-flight qos messages.
Username []byte // the username the client authenticated with.
keepalive uint16 // the number of seconds the connection can wait.
cleanSession bool // indicates if the client expects a clean-session.
packetID uint32 // the current highest packetID.
LWT LWT // the last will and testament for the client.
State State // the operational state of the client.
system *system.Info // pointers to server system info.
LWT LWT // the last will and testament for the client.
Inflight *Inflight // a map of in-flight qos messages.
sync.RWMutex // mutex
Username []byte // the username the client authenticated with.
AC auth.Controller // an auth controller inherited from the listener.
Listener string // the id of the listener the client is connected to.
ID string // the client id.
conn net.Conn // the net.Conn used to establish the connection.
R *circ.Reader // a reader for reading incoming bytes.
W *circ.Writer // a writer for writing outgoing bytes.
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
systemInfo *system.Info // pointers to server system info.
packetID uint32 // the current highest packetID.
keepalive uint16 // the number of seconds the connection can wait.
CleanSession bool // indicates if the client expects a clean-session.
}
// State tracks the state of the client.
type State struct {
Done int64 // atomic counter which indicates that the client has closed.
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
endOnce sync.Once // only end once.
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
Done uint32 // atomic counter which indicates that the client has closed.
endOnce sync.Once // only end once.
stopCause atomic.Value // reason for stopping.
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{
conn: c,
r: r,
w: w,
system: s,
keepalive: defaultKeepalive,
Inflight: Inflight{
conn: c,
R: r,
W: w,
systemInfo: s,
keepalive: defaultKeepalive,
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
@@ -139,7 +143,7 @@ func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Clie
// method is typically called by the persistence restoration system.
func NewClientStub(s *system.Info) *Client {
return &Client{
Inflight: Inflight{
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
@@ -159,11 +163,11 @@ func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
cl.ID = xid.New().String()
}
cl.r.ID = cl.ID + " READER"
cl.w.ID = cl.ID + " WRITER"
cl.R.ID = cl.ID + " READER"
cl.W.ID = cl.ID + " WRITER"
cl.Username = pk.Username
cl.cleanSession = pk.CleanSession
cl.CleanSession = pk.CleanSession
cl.keepalive = pk.Keepalive
if pk.WillFlag {
@@ -185,7 +189,20 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second)
}
cl.conn.SetDeadline(expiry)
_ = cl.conn.SetDeadline(expiry)
}
}
// Info returns an event-version of a client, containing minimal information.
func (cl *Client) Info() events.Client {
addr := "unknown"
if cl.conn != nil && cl.conn.RemoteAddr() != nil {
addr = cl.conn.RemoteAddr().String()
}
return events.Client{
ID: cl.ID,
Remote: addr,
Listener: cl.Listener,
}
}
@@ -218,47 +235,75 @@ func (cl *Client) ForgetSubscription(filter string) {
// Start begins the client goroutines reading and writing packets.
func (cl *Client) Start() {
cl.State.started.Add(2)
go func() {
cl.State.started.Done()
cl.w.WriteTo(cl.conn)
cl.State.endedW.Done()
cl.Stop()
}()
cl.State.endedW.Add(1)
cl.State.endedR.Add(1)
go func() {
cl.State.started.Done()
cl.r.ReadFrom(cl.conn)
cl.State.endedR.Done()
cl.Stop()
_, err := cl.W.WriteTo(cl.conn)
if err != nil {
err = fmt.Errorf("writer: %w", err)
}
cl.State.endedW.Done()
cl.Stop(err)
}()
go func() {
cl.State.started.Done()
_, err := cl.R.ReadFrom(cl.conn)
if err != nil {
err = fmt.Errorf("reader: %w", err)
}
cl.State.endedR.Done()
cl.Stop(err)
}()
cl.State.endedR.Add(1)
cl.State.started.Wait()
}
// ClearBuffers sets the read/write buffers to nil so they can be
// deallocated automatically when no longer in use.
func (cl *Client) ClearBuffers() {
cl.R = nil
cl.W = nil
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop() {
if atomic.LoadInt64(&cl.State.Done) == 1 {
// A cause error may be passed to identfy the reason for stopping.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return
}
cl.State.endOnce.Do(func() {
cl.r.Stop()
cl.w.Stop()
cl.R.Stop()
cl.W.Stop()
cl.State.endedW.Wait()
cl.conn.Close()
_ = cl.conn.Close() // omit close error
cl.State.endedR.Wait()
atomic.StoreInt64(&cl.State.Done, 1)
atomic.StoreUint32(&cl.State.Done, 1)
if err == nil {
err = ErrConnectionClosed
}
cl.State.stopCause.Store(err)
})
}
// readFixedHeader reads in the values of the next packet's fixed header.
// StopCause returns the reason the client connection was stopped, if any.
func (cl *Client) StopCause() error {
if cl.State.stopCause.Load() == nil {
return nil
}
return cl.State.stopCause.Load().(error)
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
p, err := cl.r.Read(1)
p, err := cl.R.Read(1)
if err != nil {
return err
}
@@ -275,7 +320,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
i := 1
n := 2
for ; n < 6; n++ {
p, err = cl.r.Read(n)
p, err = cl.R.Read(n)
if err != nil {
return err
}
@@ -299,8 +344,8 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
fh.Remaining = int(rem)
// Having successfully read n bytes, commit the tail forward.
cl.r.CommitTail(n)
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
cl.R.CommitTail(n)
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
return nil
}
@@ -309,7 +354,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
// an error is encountered (or the connection is closed).
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
for {
if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.R.CapDelta() == 0 {
return nil
}
@@ -334,18 +379,18 @@ func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.system.MessagesRecv, 1)
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
pk.FixedHeader = *fh
if pk.FixedHeader.Remaining == 0 {
return
}
p, err := cl.r.Read(pk.FixedHeader.Remaining)
p, err := cl.R.Read(pk.FixedHeader.Remaining)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
@@ -359,7 +404,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.system.PublishRecv, 1)
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
@@ -384,19 +429,19 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
}
cl.r.CommitTail(pk.FixedHeader.Remaining)
cl.R.CommitTail(pk.FixedHeader.Remaining)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if atomic.LoadInt64(&cl.State.Done) == 1 {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return 0, ErrConnectionClosed
}
cl.w.Mu.Lock()
defer cl.w.Mu.Unlock()
cl.W.Mu.Lock()
defer cl.W.Mu.Unlock()
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
@@ -407,7 +452,7 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
case packets.Publish:
err = pk.PublishEncode(buf)
if err == nil {
atomic.AddInt64(&cl.system.PublishSent, 1)
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
}
case packets.Puback:
err = pk.PubackEncode(buf)
@@ -439,12 +484,13 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
}
// Write the packet bytes to the client byte buffer.
n, err = cl.w.Write(buf.Bytes())
n, err = cl.W.Write(buf.Bytes())
if err != nil {
return
}
atomic.AddInt64(&cl.system.BytesSent, int64(n))
atomic.AddInt64(&cl.system.MessagesSent, 1)
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
cl.refreshDeadline(cl.keepalive)
@@ -453,8 +499,8 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
// LWT contains the last will and testament details for a client connection.
type LWT struct {
Topic string // the topic the will message shall be sent to.
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}

View File

@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/mochi-co/mqtt/server/events"
"github.com/mochi-co/mqtt/server/internal/circ"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/listeners/auth"
@@ -16,6 +17,8 @@ import (
"github.com/stretchr/testify/require"
)
var testClientStop = errors.New("test stop")
func genClient() *Client {
c, _ := net.Pipe()
return NewClient(c, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
@@ -130,11 +133,39 @@ func TestNewClient(t *testing.T) {
require.NotNil(t, cl)
require.NotNil(t, cl.Inflight.internal)
require.NotNil(t, cl.Subscriptions)
require.NotNil(t, cl.r)
require.NotNil(t, cl.w)
require.NotNil(t, cl.State.started)
require.NotNil(t, cl.State.endedW)
require.NotNil(t, cl.State.endedR)
require.NotNil(t, cl.R)
require.NotNil(t, cl.W)
require.Nil(t, cl.StopCause())
}
func TestClientInfoUnknown(t *testing.T) {
cl := genClient()
cl.ID = "testid"
cl.Listener = "testlistener"
cl.conn = nil
require.Equal(t, events.Client{
ID: "testid",
Remote: "unknown",
Listener: "testlistener",
}, cl.Info())
}
func TestClientInfoKnown(t *testing.T) {
c1, c2 := net.Pipe()
defer c1.Close()
defer c2.Close()
cl := genClient()
cl.ID = "ID"
cl.Listener = "L"
cl.conn = c1
require.Equal(t, events.Client{
ID: "ID",
Remote: c1.RemoteAddr().String(),
Listener: "L",
}, cl.Info())
}
func BenchmarkNewClient(b *testing.B) {
@@ -175,7 +206,7 @@ func TestClientIdentify(t *testing.T) {
cl.Identify("tcp1", pk, new(auth.Allow))
require.Equal(t, pk.Keepalive, cl.keepalive)
require.Equal(t, pk.CleanSession, cl.cleanSession)
require.Equal(t, pk.CleanSession, cl.CleanSession)
require.Equal(t, pk.ClientIdentifier, cl.ID)
}
@@ -314,15 +345,15 @@ func BenchmarkClientRefreshDeadline(b *testing.B) {
func TestClientStart(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
time.Sleep(time.Millisecond)
require.Equal(t, int64(1), atomic.LoadInt64(&cl.r.State))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.w.State))
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.R.State))
require.Equal(t, uint32(2), atomic.LoadUint32(&cl.W.State))
}
func BenchmarkClientStart(b *testing.B) {
cl := genClient()
defer cl.Stop()
defer cl.Stop(testClientStop)
for n := 0; n < b.N; n++ {
cl.Start()
@@ -332,17 +363,17 @@ func BenchmarkClientStart(b *testing.B) {
func TestClientReadFixedHeader(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
cl.r.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
cl.r.SetPos(0, 2)
cl.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
cl.R.SetPos(0, 2)
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
tail, head := cl.r.GetPos()
tail, head := cl.R.GetPos()
require.Equal(t, int64(2), tail)
require.Equal(t, int64(2), head)
@@ -351,13 +382,13 @@ func TestClientReadFixedHeader(t *testing.T) {
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
o := make(chan error)
go func() {
fh := new(packets.FixedHeader)
cl.r.Set([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}, 0, 2)
cl.r.SetPos(0, 2)
cl.R.Set([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}, 0, 2)
cl.R.SetPos(0, 2)
o <- cl.ReadFixedHeader(fh)
}()
time.Sleep(time.Millisecond)
@@ -367,17 +398,17 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
o := make(chan error)
go func() {
fh := new(packets.FixedHeader)
cl.r.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
cl.r.SetPos(0, 1)
cl.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
cl.R.SetPos(0, 1)
o <- cl.ReadFixedHeader(fh)
}()
time.Sleep(time.Millisecond)
cl.r.Stop()
cl.R.Stop()
err := <-o
require.Error(t, err)
require.Equal(t, io.EOF, err)
@@ -386,14 +417,14 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
o := make(chan error)
go func() {
fh := new(packets.FixedHeader)
err := cl.r.Set([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, 0, 5)
err := cl.R.Set([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, 0, 5)
require.NoError(t, err)
cl.r.SetPos(0, 5)
cl.R.SetPos(0, 5)
o <- cl.ReadFixedHeader(fh)
}()
time.Sleep(time.Millisecond)
@@ -403,7 +434,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
func TestClientReadOK(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
// Two packets in a row...
b := []byte{
@@ -417,9 +448,9 @@ func TestClientReadOK(t *testing.T) {
'y', 'e', 'a', 'h', // Payload
}
err := cl.r.Set(b, 0, len(b))
err := cl.R.Set(b, 0, len(b))
require.NoError(t, err)
cl.r.SetPos(0, int64(len(b)))
cl.R.SetPos(0, int64(len(b)))
o := make(chan error)
var pks []packets.Packet
@@ -431,7 +462,7 @@ func TestClientReadOK(t *testing.T) {
}()
time.Sleep(time.Millisecond)
cl.r.Stop()
cl.R.Stop()
err = <-o
require.Error(t, err)
@@ -456,15 +487,25 @@ func TestClientReadOK(t *testing.T) {
},
})
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.MessagesRecv))
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.MessagesRecv))
}
func TestClientClearBuffers(t *testing.T) {
cl := genClient()
cl.Start()
cl.Stop(testClientStop)
cl.ClearBuffers()
require.Nil(t, cl.W)
require.Nil(t, cl.R)
}
func TestClientReadDone(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
cl.State.Done = 1
err := cl.Read(func(cl *Client, pk packets.Packet) error {
@@ -477,7 +518,7 @@ func TestClientReadDone(t *testing.T) {
func TestClientReadPacketError(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
b := []byte{
0, 18,
@@ -485,9 +526,9 @@ func TestClientReadPacketError(t *testing.T) {
'a', '/', 'b', '/', 'c',
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i',
}
err := cl.r.Set(b, 0, len(b))
err := cl.R.Set(b, 0, len(b))
require.NoError(t, err)
cl.r.SetPos(0, int64(len(b)))
cl.R.SetPos(0, int64(len(b)))
o := make(chan error)
go func() {
@@ -499,10 +540,37 @@ func TestClientReadPacketError(t *testing.T) {
require.Error(t, <-o)
}
func TestClientReadPacketEOF(t *testing.T) {
cl := genClient()
cl.Start()
b := []byte{
0, 18,
0, 5,
'a', '/', 'b', '/', 'c',
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', // missing 1 byte
}
err := cl.R.Set(b, 0, len(b))
require.NoError(t, err)
cl.R.SetPos(0, int64(len(b)))
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
return nil
})
}()
cl.R.Stop()
cl.Stop(testClientStop)
require.Error(t, <-o)
require.True(t, errors.Is(cl.StopCause(), testClientStop))
}
func TestClientReadHandlerErr(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
b := []byte{
byte(packets.Publish << 4), 11, // Fixed header
@@ -511,9 +579,9 @@ func TestClientReadHandlerErr(t *testing.T) {
'y', 'e', 'a', 'h', // Payload
}
err := cl.r.Set(b, 0, len(b))
err := cl.R.Set(b, 0, len(b))
require.NoError(t, err)
cl.r.SetPos(0, int64(len(b)))
cl.R.SetPos(0, int64(len(b)))
err = cl.Read(func(cl *Client, pk packets.Packet) error {
return errors.New("test")
@@ -525,16 +593,16 @@ func TestClientReadHandlerErr(t *testing.T) {
func TestClientReadPacketOK(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
err := cl.r.Set([]byte{
err := cl.R.Set([]byte{
byte(packets.Publish << 4), 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
}, 0, 13)
require.NoError(t, err)
cl.r.SetPos(0, 13)
cl.R.SetPos(0, 13)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
@@ -557,12 +625,12 @@ func TestClientReadPacketOK(t *testing.T) {
func TestClientReadPacket(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
for i, tt := range pkTable {
err := cl.r.Set(tt.bytes, 0, len(tt.bytes))
err := cl.R.Set(tt.bytes, 0, len(tt.bytes))
require.NoError(t, err)
cl.r.SetPos(0, int64(len(tt.bytes)))
cl.R.SetPos(0, int64(len(tt.bytes)))
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
@@ -574,7 +642,7 @@ func TestClientReadPacket(t *testing.T) {
require.Equal(t, tt.packet, pk, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishRecv))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishRecv))
}
}
}
@@ -582,16 +650,16 @@ func TestClientReadPacket(t *testing.T) {
func TestClientReadPacketReadingError(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
defer cl.Stop(testClientStop)
err := cl.r.Set([]byte{
err := cl.R.Set([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
}, 0, 13)
require.NoError(t, err)
cl.r.SetPos(2, 13)
cl.R.SetPos(2, 13)
_, err = cl.ReadPacket(&packets.FixedHeader{
Type: 0,
@@ -603,8 +671,8 @@ func TestClientReadPacketReadingError(t *testing.T) {
func TestClientReadPacketReadError(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
cl.r.Stop()
defer cl.Stop(testClientStop)
cl.R.Stop()
_, err := cl.ReadPacket(&packets.FixedHeader{
Remaining: 1,
@@ -616,8 +684,8 @@ func TestClientReadPacketReadError(t *testing.T) {
func TestClientReadPacketReadUnknown(t *testing.T) {
cl := genClient()
cl.Start()
defer cl.Stop()
cl.r.Stop()
defer cl.Stop(testClientStop)
cl.R.Stop()
_, err := cl.ReadPacket(&packets.FixedHeader{
Remaining: 1,
@@ -646,11 +714,21 @@ func TestClientWritePacket(t *testing.T) {
r.Close()
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
cl.Stop()
require.Equal(t, int64(n), atomic.LoadInt64(&cl.system.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.MessagesSent))
cl.Stop(testClientStop)
// The stop cause is either the test error, EOF, or a
// closed pipe, depending on which goroutine runs first.
err = cl.StopCause()
time.Sleep(time.Millisecond * 5)
require.True(t,
errors.Is(err, testClientStop) ||
errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe))
require.Equal(t, int64(n), atomic.LoadInt64(&cl.systemInfo.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.MessagesSent))
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishSent))
}
}
}
@@ -658,8 +736,8 @@ func TestClientWritePacket(t *testing.T) {
func TestClientWritePacketWriteNoConn(t *testing.T) {
c, _ := net.Pipe()
cl := NewClient(c, circ.NewReader(16, 4), circ.NewWriter(16, 4), new(system.Info))
cl.w.SetPos(0, 16)
cl.Stop()
cl.W.SetPos(0, 16)
cl.Stop(testClientStop)
_, err := cl.WritePacket(pkTable[1].packet)
require.Error(t, err)
@@ -669,8 +747,8 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
func TestClientWritePacketWriteError(t *testing.T) {
c, _ := net.Pipe()
cl := NewClient(c, circ.NewReader(16, 4), circ.NewWriter(16, 4), new(system.Info))
cl.w.SetPos(0, 16)
cl.w.Stop()
cl.W.SetPos(0, 16)
cl.W.Stop()
_, err := cl.WritePacket(pkTable[1].packet)
require.Error(t, err)
@@ -729,7 +807,7 @@ func TestInflightGetAll(t *testing.T) {
m := cl.Inflight.GetAll()
o := map[uint16]InflightMessage{
2: InflightMessage{},
2: {},
}
require.Equal(t, o, m)
}

View File

@@ -28,6 +28,10 @@ func decodeString(buf []byte, offset int) (string, int, error) {
return "", 0, err
}
if !validUTF8(b) {
return "", 0, ErrOffsetStrInvalidUTF8
}
return bytesToString(b), n, nil
}
@@ -39,12 +43,10 @@ func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
}
if next+int(length) > len(buf) {
return make([]byte, 0, 0), 0, ErrOffsetStrOutOfRange
return make([]byte, 0, 0), 0, ErrOffsetBytesOutOfRange
}
if !validUTF8(buf[next : next+int(length)]) {
return make([]byte, 0, 0), 0, ErrOffsetStrInvalidUTF8
}
// Note: there is no validUTF8() test for []byte payloads
return buf[next : next+int(length)], next + int(length), nil
}

View File

@@ -1,6 +1,8 @@
package packets
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/require"
@@ -19,15 +21,16 @@ func BenchmarkBytesToString(b *testing.B) {
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
rawBytes []byte
result []string
result string
offset int
shouldFail bool
shouldFail error
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: []string{"a/b/c/d", "a"},
result: "a/b/c/d",
},
{
offset: 14,
@@ -41,36 +44,32 @@ func TestDecodeString(t *testing.T) {
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "zen"},
},
result: []string{"hey"},
result: "hey",
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: []string{"1/2/3/4/a/b/c/d/e/^/@/!", "a"},
result: "1/2/3/4/a/b/c/d/e/^/@/!",
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: []string{"x/y/z", "!@#$%^&"},
result: "x/y/z",
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
result: []string{"a/b/c/d", "z"},
shouldFail: true,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
result: []string{"a/b/c/d", "x"},
shouldFail: true,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
result: []string{"a/b/c/d", "y"},
shouldFail: true,
shouldFail: ErrOffsetUintOutOfRange,
},
{
offset: 17,
@@ -86,20 +85,26 @@ func TestDecodeString(t *testing.T) {
0, 6, // Will Topic - MSB+LSB
'l',
},
result: []string{"lwt"},
shouldFail: true,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrOffsetStrInvalidUTF8,
},
}
for i, wanted := range expect {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding string [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding string [i:%d]", i)
require.Equal(t, wanted.result[0], result, "Incorrect decoded value [i:%d]", i)
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
@@ -116,7 +121,7 @@ func TestDecodeBytes(t *testing.T) {
result []uint8
next int
offset int
shouldFail bool
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session)
@@ -132,33 +137,27 @@ func TestDecodeBytes(t *testing.T) {
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 0,
shouldFail: true,
shouldFail: ErrOffsetBytesOutOfRange,
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 8,
shouldFail: true,
},
{
rawBytes: []byte{0, 4, 77, 81},
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
offset: 0,
shouldFail: true,
shouldFail: ErrOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding bytes [i:%d]", i)
continue
}
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err, "Error decoding bytes [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
@@ -174,7 +173,7 @@ func TestDecodeByte(t *testing.T) {
rawBytes []byte
result uint8
offset int
shouldFail bool
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
@@ -198,22 +197,23 @@ func TestDecodeByte(t *testing.T) {
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
result: uint8(0x00),
offset: 8,
shouldFail: true,
shouldFail: ErrOffsetByteOutOfRange,
},
}
for i, wanted := range expect {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding byte [i:%d]", i)
continue
}
require.NoError(t, err, "Error decoding byte [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, i+1, offset, "Incorrect offset value [i:%d]", i)
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
})
}
}
@@ -229,7 +229,7 @@ func TestDecodeUint16(t *testing.T) {
rawBytes []byte
result uint16
offset int
shouldFail bool
shouldFail error
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
@@ -243,22 +243,24 @@ func TestDecodeUint16(t *testing.T) {
},
{
rawBytes: []byte{0, 7, 255, 47},
result: uint16(0x761),
offset: 8,
shouldFail: true,
shouldFail: ErrOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding uint16 [i:%d]", i)
continue
}
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err, "Error decoding uint16 [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, i+2, offset, "Incorrect offset value [i:%d]", i)
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
})
}
}
@@ -274,7 +276,7 @@ func TestDecodeByteBool(t *testing.T) {
rawBytes []byte
result bool
offset int
shouldFail bool
shouldFail error
}{
{
rawBytes: []byte{0x00, 0x00},
@@ -287,20 +289,22 @@ func TestDecodeByteBool(t *testing.T) {
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: true,
shouldFail: ErrOffsetBoolOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail {
require.Error(t, err, "Expected error decoding byte bool [i:%d]", i)
continue
}
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err, "Error decoding byte bool [i:%d]", i)
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
require.Equal(t, 1, offset, "Incorrect offset value [i:%d]", i)
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, 1, offset)
})
}
}

View File

@@ -6,20 +6,20 @@ import (
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Dup bool // indicates if the packet was already sent at an earlier time.
Qos byte // indicates the quality of service expected.
Retain bool // whether the message should be retained.
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, fh.Remaining)
encodeLength(buf, int64(fh.Remaining))
}
// decode extracts the specification bits from the header byte.
// Decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(headerByte byte) error {
fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes.
@@ -44,7 +44,7 @@ func (fh *FixedHeader) Decode(headerByte byte) error {
}
// encodeLength writes length bits for the header.
func encodeLength(buf *bytes.Buffer, length int) {
func encodeLength(buf *bytes.Buffer, length int64) {
for {
digit := byte(length % 128)
length /= 128

View File

@@ -18,130 +18,130 @@ type fixedHeaderTable struct {
var fixedHeaderExpected = []fixedHeaderTable{
{
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Connack, false, 0, false, 0},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Publish, false, 0, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Publish, false, 1, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 1, true, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Publish, false, 2, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 2, true, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Publish, true, 0, false, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Publish, true, 0, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 1, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 2, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Puback, false, 0, false, 0},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Pubrec, false, 0, false, 0},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Pubrel, false, 1, false, 0},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Pubcomp, false, 0, false, 0},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Subscribe, false, 1, false, 0},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Suback, false, 0, false, 0},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Unsubscribe, false, 1, false, 0},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Unsuback, false, 0, false, 0},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Pingreq, false, 0, false, 0},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Pingresp, false, 0, false, 0},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Disconnect, false, 0, false, 0},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Publish, false, 0, false, 10},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Publish, false, 0, false, 512},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Publish, false, 0, false, 978},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Publish, false, 0, false, 20102},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Publish, false, 0, false, 333333333},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Connect, true, 0, false, 0},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Connect, false, 1, false, 0},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Connect, false, 0, true, 0},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
flagError: true,
},
}
@@ -192,7 +192,7 @@ func BenchmarkFixedHeaderDecode(b *testing.B) {
func TestEncodeLength(t *testing.T) {
tt := []struct {
have int
have int64
want []byte
}{
{

View File

@@ -3,6 +3,8 @@ package packets
import (
"bytes"
"errors"
"fmt"
"strconv"
)
// All of the valid packet types and their packet identifier.
@@ -60,7 +62,6 @@ var (
// PACKETS
ErrProtocolViolation = errors.New("protocol violation")
ErrOffsetStrOutOfRange = errors.New("offset string out of range")
ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range")
ErrOffsetByteOutOfRange = errors.New("offset byte out of range")
ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")
@@ -76,40 +77,31 @@ var (
// packet structs, this is a single concrete packet type to cover all packet
// types, which allows us to take advantage of various compiler optimizations.
type Packet struct {
FixedHeader FixedHeader
PacketID uint16
// Connect
FixedHeader FixedHeader
AllowClients []string // For use with OnMessage event hook.
Topics []string
ReturnCodes []byte
ProtocolName []byte
Qoss []byte
Payload []byte
Username []byte
Password []byte
WillMessage []byte
ClientIdentifier string
TopicName string
WillTopic string
PacketID uint16
Keepalive uint16
ReturnCode byte
ProtocolVersion byte
WillQos byte
ReservedBit byte
CleanSession bool
WillFlag bool
WillQos byte
WillRetain bool
UsernameFlag bool
PasswordFlag bool
ReservedBit byte
Keepalive uint16
ClientIdentifier string
WillTopic string
WillMessage []byte
Username []byte
Password []byte
// Connack
SessionPresent bool
ReturnCode byte
// Publish
TopicName string
Payload []byte
// Subscribe, Unsubscribe
Topics []string
Qoss []byte
ReturnCodes []byte // Suback
SessionPresent bool
}
// ConnectEncode encodes a connect packet.
@@ -169,17 +161,17 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
// Unpack protocol name and version.
pk.ProtocolName, offset, err = decodeBytes(buf, 0)
if err != nil {
return ErrMalformedProtocolName
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolName)
}
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedProtocolVersion
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolVersion)
}
// Unpack flags byte.
flags, offset, err := decodeByte(buf, offset)
if err != nil {
return ErrMalformedFlags
return fmt.Errorf("%s: %w", err, ErrMalformedFlags)
}
pk.ReservedBit = 1 & flags
pk.CleanSession = 1&(flags>>1) > 0
@@ -192,25 +184,25 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
// Get keepalive interval.
pk.Keepalive, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedKeepalive
return fmt.Errorf("%s: %w", err, ErrMalformedKeepalive)
}
// Get client ID.
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedClientID
return fmt.Errorf("%s: %w", err, ErrMalformedClientID)
}
// Get Last Will and Testament topic and message if applicable.
if pk.WillFlag {
pk.WillTopic, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedWillTopic
return fmt.Errorf("%s: %w", err, ErrMalformedWillTopic)
}
pk.WillMessage, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedWillMessage
return fmt.Errorf("%s: %w", err, ErrMalformedWillMessage)
}
}
@@ -218,14 +210,14 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if pk.UsernameFlag {
pk.Username, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedUsername
return fmt.Errorf("%s: %w", err, ErrMalformedUsername)
}
}
if pk.PasswordFlag {
pk.Password, offset, err = decodeBytes(buf, offset)
if err != nil {
return ErrMalformedPassword
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
}
}
@@ -292,12 +284,12 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
if err != nil {
return ErrMalformedSessionPresent
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
}
pk.ReturnCode, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedReturnCode
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
}
return nil
@@ -334,7 +326,7 @@ func (pk *Packet) PubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
@@ -352,7 +344,7 @@ func (pk *Packet) PubcompDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
@@ -390,14 +382,14 @@ func (pk *Packet) PublishDecode(buf []byte) error {
pk.TopicName, offset, err = decodeString(buf, 0)
if err != nil {
return ErrMalformedTopic
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
// If QOS decode Packet ID.
if pk.FixedHeader.Qos > 0 {
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
}
@@ -451,7 +443,7 @@ func (pk *Packet) PubrecDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
@@ -470,7 +462,7 @@ func (pk *Packet) PubrelDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
@@ -495,7 +487,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
// Get Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, offset)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Get Granted QOS flags.
@@ -542,7 +534,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Keep decoding until there's no space left.
@@ -552,7 +544,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
var topic string
topic, offset, err = decodeString(buf, offset)
if err != nil {
return ErrMalformedTopic
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
pk.Topics = append(pk.Topics, topic)
@@ -560,7 +552,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
var qos byte
qos, offset, err = decodeByte(buf, offset)
if err != nil {
return ErrMalformedQoS
return fmt.Errorf("%s: %w", err, ErrMalformedQoS)
}
// Ensure QoS byte is within range.
@@ -599,7 +591,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
var err error
pk.PacketID, _, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
return nil
}
@@ -641,7 +633,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
// Get the Packet ID.
pk.PacketID, offset, err = decodeUint16(buf, 0)
if err != nil {
return ErrMalformedPacketID
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
}
// Keep decoding until there's no space left.
@@ -649,7 +641,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
var t string
t, offset, err = decodeString(buf, offset) // Decode Topic Name.
if err != nil {
return ErrMalformedTopic
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
}
if len(t) > 0 {
@@ -671,3 +663,8 @@ func (pk *Packet) UnsubscribeValidate() (byte, error) {
return Accepted, nil
}
// FormatID returns the PacketID field as a decimal integer.
func (pk *Packet) FormatID() string {
return strconv.FormatUint(uint64(pk.PacketID), 10)
}

View File

@@ -6,7 +6,7 @@ type packetTestData struct {
actualBytes []byte // the actual byte array that is created in the event of a byte mutation (eg. MQTT-2.3.1-1 qos/packet id)
packet *Packet // the packet that is expected
desc string // a description of the test
failFirst interface{} // expected fail result to be run immediately after the method is called
failFirst error // expected fail result to be run immediately after the method is called
expect interface{} // generic expected fail result to be checked
isolate bool // isolate can be used to isolate a test
primary bool // primary is a test that should be run using readPackets

View File

@@ -2,6 +2,8 @@ package packets
import (
"bytes"
"errors"
"fmt"
"testing"
"github.com/jinzhu/copier"
@@ -74,13 +76,13 @@ func TestConnectDecode(t *testing.T) {
}
require.Equal(t, uint8(1), Connect, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficent bytes in packet [i:%d] %s", i, wanted.desc)
require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficient bytes in packet [i:%d] %s", i, wanted.desc)
pk := &Packet{FixedHeader: FixedHeader{Type: Connect}}
err := pk.ConnectDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -180,7 +182,7 @@ func TestConnackDecode(t *testing.T) {
err := pk.ConnackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -339,7 +341,7 @@ func TestPubackDecode(t *testing.T) {
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -407,7 +409,7 @@ func TestPubcompDecode(t *testing.T) {
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -494,7 +496,7 @@ func TestPublishDecode(t *testing.T) {
err := pk.PublishDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected fh error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -630,7 +632,7 @@ func TestPubrecDecode(t *testing.T) {
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -703,7 +705,7 @@ func TestPubrelDecode(t *testing.T) {
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -775,7 +777,7 @@ func TestSubackDecode(t *testing.T) {
err := pk.SubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -854,7 +856,7 @@ func TestSubscribeDecode(t *testing.T) {
err := pk.SubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -957,7 +959,7 @@ func TestUnsubackDecode(t *testing.T) {
err := pk.UnsubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -1034,7 +1036,7 @@ func TestUnsubscribeDecode(t *testing.T) {
err := pk.UnsubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
continue
}
@@ -1080,3 +1082,10 @@ func BenchmarkUnsubscribeValidate(b *testing.B) {
pk.UnsubscribeValidate()
}
}
func TestFormatPacketID(t *testing.T) {
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
packet := &Packet{PacketID: id}
require.Equal(t, fmt.Sprint(id), packet.FormatID())
}
}

View File

@@ -27,27 +27,28 @@ func New() *Index {
}
// RetainMessage saves a message payload to the end of a topic branch. Returns
// 1 if a retained message was added, 0 if there was no change, and -1 if the
// retained message was removed.
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *Index) RetainMessage(msg packets.Packet) int64 {
var q int64
x.mu.Lock()
defer x.mu.Unlock()
n := x.poperate(msg.TopicName)
// If there is a payload, we can store it.
if len(msg.Payload) > 0 {
if n.Message.FixedHeader.Retain == false {
q = 1
}
n.Message = msg
} else {
if n.Message.FixedHeader.Retain == true {
q = -1
}
x.unpoperate(msg.TopicName, "", true)
return 1
}
return q
// Otherwise, we are unsetting it.
// If there was a previous retained message, return -1 instead of 0.
var r int64 = 0
if len(n.Message.Payload) > 0 && n.Message.FixedHeader.Retain == true {
r = -1
}
x.unpoperate(msg.TopicName, "", true)
return r
}
// Subscribe creates a subscription filter for a client. Returns true if the
@@ -175,12 +176,12 @@ func (x *Index) Messages(filter string) []packets.Packet {
// Leaf is a child node on the tree.
type Leaf struct {
Message packets.Packet // a message which has been retained for a specific topic.
Key string // the key that was used to create the leaf.
Filter string // the path of the topic filter being matched.
Parent *Leaf // a pointer to the parent node for the leaf.
Leaves map[string]*Leaf // a map of child nodes, keyed on particle id.
Clients map[string]byte // a map of client ids subscribed to the topic.
Filter string // the path of the topic filter being matched.
Message packets.Packet // a message which has been retained for a specific topic.
}
// scanSubscribers recursively steps through a branch of leaves finding clients who

View File

@@ -113,8 +113,10 @@ func TestRetainMessage(t *testing.T) {
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
q = index.RetainMessage(pk2) // already exsiting
require.Equal(t, int64(0), q)
// The same message already exists, but we're not doing a deep-copy check, so it's considered
// to be a new message.
q = index.RetainMessage(pk2)
require.Equal(t, int64(1), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
@@ -126,6 +128,14 @@ func TestRetainMessage(t *testing.T) {
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
// Second Delete retained
q = index.RetainMessage(pk3)
require.Equal(t, int64(0), q)
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
}
func BenchmarkRetainMessage(b *testing.B) {

View File

@@ -0,0 +1,14 @@
package utils
// InSliceString returns true if a string exists in a slice of strings.
// This temporary and should be replaced with a function from the new
// go slices package in 1.19 when available.
// https://github.com/golang/go/issues/45955
func InSliceString(sl []string, st string) bool {
for _, v := range sl {
if st == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,18 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInSliceString(t *testing.T) {
sl := []string{"a", "b", "c"}
require.Equal(t, true, InSliceString(sl, "b"))
sl = []string{"a", "a", "a"}
require.Equal(t, true, InSliceString(sl, "a"))
sl = []string{"a", "b", "c"}
require.Equal(t, false, InSliceString(sl, "d"))
}

View File

@@ -3,7 +3,7 @@ package auth
// Allow is an auth controller which allows access to all connections and topics.
type Allow struct{}
// Auth returns true if a username and password are acceptable. Allow always
// Authenticate returns true if a username and password are acceptable. Allow always
// returns true.
func (a *Allow) Authenticate(user, password []byte) bool {
return true
@@ -18,7 +18,7 @@ func (a *Allow) ACL(user []byte, topic string, write bool) bool {
// Disallow is an auth controller which disallows access to all connections and topics.
type Disallow struct{}
// Auth returns true if a username and password are acceptable. Disallow always
// Authenticate returns true if a username and password are acceptable. Disallow always
// returns false.
func (d *Disallow) Authenticate(user, password []byte) bool {
return false

View File

@@ -18,11 +18,11 @@ import (
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
system *system.Info // pointers to the server data.
address string // the network address to bind to.
listen *http.Server // the http server.
end int64 // ensure the close methods are only called once.}
end uint32 // ensure the close methods are only called once.}
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
@@ -98,9 +98,7 @@ func (l *HTTPStats) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)

View File

@@ -38,10 +38,10 @@ type Listener interface {
// Listeners contains the network listeners for the broker.
type Listeners struct {
sync.RWMutex
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
internal map[string]Listener // a map of active listeners.
system *system.Info // pointers to system info.
sync.RWMutex
}
// New returns a new instance of Listeners.

View File

@@ -22,11 +22,11 @@ func MockEstablisher(id string, c net.Conn, ac auth.Controller) error {
type MockListener struct {
sync.RWMutex
id string // the id of the listener.
Config *Config // configuration for the listener.
address string // the network address the listener binds to.
Listening bool // indiciate the listener is listening.
Serving bool // indicate the listener is serving.
Config *Config // configuration for the listener.
done chan bool // indicate the listener is done.
Serving bool // indicate the listener is serving.
Listening bool // indiciate the listener is listening.
ErrListen bool // throw an error on listen.
}
@@ -44,15 +44,12 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
l.Lock()
l.Serving = true
l.Unlock()
for {
select {
case <-l.done:
return
}
for range l.done {
return
}
}
// SetConfig sets the configuration values of the mock listener.
// Listen begins listening for incoming traffic.
func (l *MockListener) Listen(s *system.Info) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
@@ -95,7 +92,7 @@ func (l *MockListener) IsServing() bool {
return l.Serving
}
// IsServing indicates whether the mock listener is listening.
// IsListening indicates whether the mock listener is listening.
func (l *MockListener) IsListening() bool {
l.Lock()
defer l.Unlock()

View File

@@ -14,11 +14,11 @@ import (
type TCP struct {
sync.RWMutex
id string // the internal id of the listener.
config *Config // configuration values for the listener.
protocol string // the TCP protocol to use.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
end int64 // ensure the close methods are only called once.
config *Config // configuration values for the listener.
end uint32 // ensure the close methods are only called once.
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
@@ -63,10 +63,12 @@ func (l *TCP) Listen(s *system.Info) error {
var err error
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
@@ -84,7 +86,7 @@ func (l *TCP) Listen(s *system.Info) error {
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFunc) {
for {
if atomic.LoadInt64(&l.end) == 1 {
if atomic.LoadUint32(&l.end) == 1 {
return
}
@@ -93,8 +95,10 @@ func (l *TCP) Serve(establish EstablishFunc) {
return
}
if atomic.LoadInt64(&l.end) == 0 {
go establish(l.id, conn, l.config.Auth)
if atomic.LoadUint32(&l.end) == 0 {
go func() {
_ = establish(l.id, conn, l.config.Auth)
}()
}
}
}
@@ -104,8 +108,7 @@ func (l *TCP) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}

View File

@@ -17,12 +17,14 @@ import (
)
var (
ErrInvalidMessage = errors.New("Message type not binary")
// ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool { return true },
}
)
@@ -30,11 +32,11 @@ var (
type Websocket struct {
sync.RWMutex
id string // the internal id of the listener.
config *Config // configuration values for the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
listen *http.Server // an http server for serving websocket connections.
end int64 // ensure the close methods are only called once.
establish EstablishFunc // the server's establish conection handler.
establish EstablishFunc // the server's establish connection handler.
end uint32 // ensure the close methods are only called once.
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
@@ -161,8 +163,7 @@ func (l *Websocket) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)

View File

@@ -1,10 +1,8 @@
package bolt
import (
"errors"
"time"
"fmt"
"time"
sgob "github.com/asdine/storm/codec/gob"
"github.com/asdine/storm/v3"
@@ -14,12 +12,17 @@ import (
)
const (
defaultPath = "mochi.db"
// defaultPath is the default file to use to store the data.
defaultPath = "mochi.db"
// defaultTimeout is the default timeout of the file lock.
defaultTimeout = 250 * time.Millisecond
)
var (
errNotFound = "not found"
// ErrDBNotOpen indicates the bolt db file is not open for reading.
ErrDBNotOpen = fmt.Errorf("boltdb not opened")
)
// Store is a backend for writing and reading to bolt persistent storage.
@@ -66,7 +69,7 @@ func (s *Store) Close() {
// WriteServerInfo writes the server info to the boltdb instance.
func (s *Store) WriteServerInfo(v persistence.ServerInfo) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -79,7 +82,7 @@ func (s *Store) WriteServerInfo(v persistence.ServerInfo) error {
// WriteSubscription writes a single subscription to the boltdb instance.
func (s *Store) WriteSubscription(v persistence.Subscription) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -92,7 +95,7 @@ func (s *Store) WriteSubscription(v persistence.Subscription) error {
// WriteInflight writes a single inflight message to the boltdb instance.
func (s *Store) WriteInflight(v persistence.Message) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -105,7 +108,7 @@ func (s *Store) WriteInflight(v persistence.Message) error {
// WriteRetained writes a single retained message to the boltdb instance.
func (s *Store) WriteRetained(v persistence.Message) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -118,7 +121,7 @@ func (s *Store) WriteRetained(v persistence.Message) error {
// WriteClient writes a single client to the boltdb instance.
func (s *Store) WriteClient(v persistence.Client) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.Save(&v)
@@ -131,7 +134,7 @@ func (s *Store) WriteClient(v persistence.Client) error {
// DeleteSubscription deletes a subscription from the boltdb instance.
func (s *Store) DeleteSubscription(id string) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Subscription{
@@ -147,7 +150,7 @@ func (s *Store) DeleteSubscription(id string) error {
// DeleteClient deletes a client from the boltdb instance.
func (s *Store) DeleteClient(id string) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Client{
@@ -163,7 +166,7 @@ func (s *Store) DeleteClient(id string) error {
// DeleteInflight deletes an inflight message from the boltdb instance.
func (s *Store) DeleteInflight(id string) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Message{
@@ -179,7 +182,7 @@ func (s *Store) DeleteInflight(id string) error {
// DeleteRetained deletes a retained message from the boltdb instance.
func (s *Store) DeleteRetained(id string) error {
if s.db == nil {
return errors.New("boltdb not opened")
return ErrDBNotOpen
}
err := s.db.DeleteStruct(&persistence.Message{
@@ -195,11 +198,11 @@ func (s *Store) DeleteRetained(id string) error {
// ReadSubscriptions loads all the subscriptions from the boltdb instance.
func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KSubscription, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}
@@ -209,11 +212,11 @@ func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) {
// ReadClients loads all the clients from the boltdb instance.
func (s *Store) ReadClients() (v []persistence.Client, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KClient, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}
@@ -223,11 +226,11 @@ func (s *Store) ReadClients() (v []persistence.Client, err error) {
// ReadInflight loads all the inflight messages from the boltdb instance.
func (s *Store) ReadInflight() (v []persistence.Message, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KInflight, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}
@@ -237,11 +240,11 @@ func (s *Store) ReadInflight() (v []persistence.Message, err error) {
// ReadRetained loads all the retained messages from the boltdb instance.
func (s *Store) ReadRetained() (v []persistence.Message, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.Find("T", persistence.KRetained, &v)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}
@@ -251,12 +254,11 @@ func (s *Store) ReadRetained() (v []persistence.Message, err error) {
//ReadServerInfo loads the server info from the boltdb instance.
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
if s.db == nil {
return v, errors.New("boltdb not opened")
return v, ErrDBNotOpen
}
err = s.db.One("ID", persistence.KServerInfo, &v)
fmt.Println(err)
if err != nil && err.Error() != errNotFound {
if err != nil && err != storm.ErrNotFound {
return
}

View File

@@ -7,11 +7,21 @@ import (
)
const (
// KSubscription is the key for subscription data.
KSubscription = "sub"
KServerInfo = "srv"
KRetained = "ret"
KInflight = "ifm"
KClient = "cl"
// KServerInfo is the key for server info data.
KServerInfo = "srv"
// KRetained is the key for retained messages data.
KRetained = "ret"
// KInflight is the key for inflight messages data.
KInflight = "ifm"
// KClient is the key for client data.
KClient = "cl"
)
// Store is an interface which details a persistent storage connector.
@@ -53,50 +63,50 @@ type Subscription struct {
// Message contains the details of a retained or inflight message.
type Message struct {
ID string // the storage key.
T string // the type of the stored data.
Client string // the id of the client who sent the message (if inflight).
FixedHeader FixedHeader // the header properties of the message.
PacketID uint16 // the unique id of the packet (if inflight).
TopicName string // the topic the message was sent to (if retained).
Payload []byte // the message payload (if retained).
FixedHeader FixedHeader // the header properties of the message.
T string // the type of the stored data.
ID string // the storage key.
Client string // the id of the client who sent the message (if inflight).
TopicName string // the topic the message was sent to (if retained).
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
Resends int // the number of times the message was attempted to be sent (if inflight).
PacketID uint16 // the unique id of the packet (if inflight).
}
// FixedHeader contains the fixed header properties of a message.
type FixedHeader struct {
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Dup bool // indicates if the packet was already sent at an earlier time.
Qos byte // indicates the quality of service expected.
Retain bool // whether the message should be retained.
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Client contains client data that can be persistently stored.
type Client struct {
LWT LWT // the last-will-and-testament message for the client.
Username []byte // the username the client authenticated with.
ID string // the storage key.
ClientID string // the id of the client.
T string // the type of the stored data.
Listener string // the last known listener id for the client
Username []byte // the username the client authenticated with.
LWT LWT // the last-will-and-testament message for the client.
}
// LWT contains details about a clients LWT payload.
type LWT struct {
Topic string // the topic the will message shall be sent to.
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}
// MockStore is a mock storage backend for testing.
type MockStore struct {
Fail map[string]bool // issue errors for different methods.
FailOpen bool // error on open.
Closed bool // indicate mock store is closed.
Opened bool // indicate mock store is open.
Fail map[string]bool // issue errors for different methods.
}
// Open opens the storage instance.
@@ -197,7 +207,7 @@ func (s *MockStore) ReadSubscriptions() (v []Subscription, err error) {
}
return []Subscription{
Subscription{
{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
@@ -214,7 +224,7 @@ func (s *MockStore) ReadClients() (v []Client, err error) {
}
return []Client{
Client{
{
ID: "cl_client1",
ClientID: "client1",
T: KClient,
@@ -230,7 +240,7 @@ func (s *MockStore) ReadInflight() (v []Message, err error) {
}
return []Message{
Message{
{
ID: "client1_if_100",
T: KInflight,
Client: "client1",
@@ -250,7 +260,7 @@ func (s *MockStore) ReadRetained() (v []Message, err error) {
}
return []Message{
Message{
{
ID: "client1_ret_200",
T: KRetained,
FixedHeader: FixedHeader{

View File

@@ -1,9 +1,10 @@
// packet server provides a MQTT 3.1.1 compliant MQTT server.
// package server provides a MQTT 3.1.1 compliant MQTT server.
package server
import (
"errors"
"fmt"
"io"
"net"
"strconv"
"sync/atomic"
@@ -14,6 +15,7 @@ import (
"github.com/mochi-co/mqtt/server/internal/clients"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
"github.com/mochi-co/mqtt/server/internal/utils"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/persistence"
@@ -21,14 +23,41 @@ import (
)
const (
Version = "1.0.1" // the server version.
// Version indicates the current server version.
Version = "1.1.1"
)
var (
ErrListenerIDExists = errors.New("Listener id already exists")
ErrReadConnectInvalid = errors.New("Connect packet was not valid")
ErrConnectNotAuthorized = errors.New("Connect packet was not authorized")
ErrInvalidTopic = errors.New("Cannot publish to $ and $SYS topics")
// ErrListenerIDExists indicates that a listener with the same id already exists.
ErrListenerIDExists = errors.New("listener id already exists")
// ErrReadConnectInvalid indicates that the connection packet was invalid.
ErrReadConnectInvalid = errors.New("connect packet was not valid")
// ErrConnectNotAuthorized indicates that the connection packet had incorrect auth values.
ErrConnectNotAuthorized = errors.New("connect packet was not authorized")
// ErrInvalidTopic indicates that the specified topic was not valid.
ErrInvalidTopic = errors.New("cannot publish to $ and $SYS topics")
// ErrRejectPacket indicates that a packet should be dropped instead of processed.
ErrRejectPacket = errors.New("packet rejected")
// ErrClientDisconnect indicates that a client disconnected from the server.
ErrClientDisconnect = errors.New("client disconnected")
// ErrClientReconnect indicates that a client attempted to reconnect while still connected.
ErrClientReconnect = errors.New("client sent connect while connected")
// ErrServerShutdown is propagated when the server shuts down.
ErrServerShutdown = errors.New("server is shutting down")
// ErrSessionReestablished indicates that an existing client was replaced by a newly connected
// client. The existing client is disconnected.
ErrSessionReestablished = errors.New("client session re-established")
// ErrConnectionFailed indicates that a client connection attempt failed for other reasons.
ErrConnectionFailed = errors.New("connection attempt failed")
// SysTopicInterval is the number of milliseconds between $SYS topic publishes.
SysTopicInterval time.Duration = 30000
@@ -44,16 +73,26 @@ var (
// Server is an MQTT broker server. It should be created with server.New()
// in order to ensure all the internal fields are correctly populated.
type Server struct {
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
Store persistence.Store // a persistent storage backend if desired.
Options *Options // configurable server options.
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
Clients *clients.Clients // clients which are known to the broker.
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
System *system.Info // values about the server commonly found in $SYS topics.
Store persistence.Store // a persistent storage backend if desired.
done chan bool // indicate that the server is ending.
bytepool circ.BytesPool // a byte pool for incoming and outgoing packets.
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
done chan bool // indicate that the server is ending.
}
// Options contains configurable options for the server.
type Options struct {
// BufferSize overrides the default buffer size (circ.DefaultBufferSize) for the client buffers.
BufferSize int
// BufferBlockSize overrides the default buffer block size (DefaultBlockSize) for the client buffers.
BufferBlockSize int
}
// inlineMessages contains channels for handling inline (direct) publishing.
@@ -62,11 +101,22 @@ type inlineMessages struct {
pub chan packets.Packet // a channel of packets to publish to clients
}
// New returns a new instance of an MQTT broker.
// New returns a new instance of MQTT server with no options.
// This method has been deprecated and will be removed in a future release.
// Please use NewServer instead.
func New() *Server {
return NewServer(nil)
}
// NewServer returns a new instance of an MQTT broker with optional values where applicable.
func NewServer(opts *Options) *Server {
if opts == nil {
opts = new(Options)
}
s := &Server{
done: make(chan bool),
bytepool: circ.NewBytesPool(circ.DefaultBufferSize),
bytepool: circ.NewBytesPool(opts.BufferSize),
Clients: clients.New(),
Topics: topics.New(),
System: &system.Info{
@@ -78,7 +128,8 @@ func New() *Server {
done: make(chan bool),
pub: make(chan packets.Packet, 1024),
},
Events: events.Events{},
Events: events.Events{},
Options: opts,
}
// Expose server stats using the system listener so it can be used in the
@@ -165,118 +216,200 @@ func (s *Server) inlineClient() {
}
}
// readConnectionPacket reads the first incoming header for a connection, and if
// acceptable, returns the valid connection packet.
func (s *Server) readConnectionPacket(cl *clients.Client) (pk packets.Packet, err error) {
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
return
}
pk, err = cl.ReadPacket(fh)
if err != nil {
return
}
if pk.FixedHeader.Type != packets.Connect {
return pk, ErrReadConnectInvalid
}
return
}
// onError is a pass-through method which triggers the OnError
// event hook (if applicable), and returns the provided error.
func (s *Server) onError(cl events.Client, err error) error {
if err == nil {
return err
}
// Note: if the error originates from a real cause, it will
// have been captured as the StopCause. The two cases ignored
// below are ordinary consequences of closing the connection.
// If one of these ordinary conditions stops the connection,
// then the client closed or broke the connection.
if s.Events.OnError != nil &&
!errors.Is(err, io.EOF) {
s.Events.OnError(cl, err)
}
return err
}
// onStorage is a pass-through method which delegates errors from
// the persistent storage adapter to the onError event hook.
func (s *Server) onStorage(cl events.Clientlike, err error) {
if err == nil {
return
}
_ = s.onError(cl.Info(), fmt.Errorf("storage: %w", err))
}
// EstablishConnection establishes a new client when a listener
// accepts a new connection.
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
xbr := s.bytepool.Get() // Get byte buffer from pools for receiving packet data.
xbw := s.bytepool.Get() // and for sending.
defer s.bytepool.Put(xbr)
defer s.bytepool.Put(xbw)
cl := clients.NewClient(c,
circ.NewReaderFromSlice(0, xbr),
circ.NewWriterFromSlice(0, xbw),
circ.NewReaderFromSlice(s.Options.BufferBlockSize, xbr),
circ.NewWriterFromSlice(s.Options.BufferBlockSize, xbw),
s.System,
)
cl.Start()
defer cl.ClearBuffers()
defer cl.Stop(nil)
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
pk, err := s.readConnectionPacket(cl)
if err != nil {
return err
return s.onError(cl.Info(), fmt.Errorf("read connection: %w", err))
}
pk, err := cl.ReadPacket(fh)
ackCode, err := pk.ConnectValidate()
if err != nil {
return err
if err := s.ackConnection(cl, ackCode, false); err != nil {
return s.onError(cl.Info(), fmt.Errorf("invalid connection send ack: %w", err))
}
return s.onError(cl.Info(), fmt.Errorf("validate connection packet: %w", err))
}
if pk.FixedHeader.Type != packets.Connect {
return ErrReadConnectInvalid
}
cl.Identify(lid, pk, ac) // Set client identity values from the connection packet.
cl.Identify(lid, pk, ac)
retcode, _ := pk.ConnectValidate()
if !ac.Authenticate(pk.Username, pk.Password) {
retcode = packets.CodeConnectBadAuthValues
if err := s.ackConnection(cl, packets.CodeConnectBadAuthValues, false); err != nil {
return s.onError(cl.Info(), fmt.Errorf("invalid connection send ack: %w", err))
}
return s.onError(cl.Info(), ErrConnectionFailed)
}
atomic.AddInt64(&s.System.ConnectionsTotal, 1)
atomic.AddInt64(&s.System.ClientsConnected, 1)
defer atomic.AddInt64(&s.System.ClientsConnected, -1)
defer atomic.AddInt64(&s.System.ClientsDisconnected, 1)
var sessionPresent bool
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
existing.Lock()
if atomic.LoadInt64(&existing.State.Done) == 1 {
atomic.AddInt64(&s.System.ClientsDisconnected, -1)
}
existing.Stop()
if pk.CleanSession {
for k := range existing.Subscriptions {
delete(existing.Subscriptions, k)
q := s.Topics.Unsubscribe(k, existing.ID)
if q {
atomic.AddInt64(&s.System.Subscriptions, -1)
}
}
} else {
cl.Inflight = existing.Inflight // Inherit from existing session.
cl.Subscriptions = existing.Subscriptions
sessionPresent = true
}
existing.Unlock()
} else {
atomic.AddInt64(&s.System.ClientsTotal, 1)
if atomic.LoadInt64(&s.System.ClientsConnected) > atomic.LoadInt64(&s.System.ClientsMax) {
atomic.AddInt64(&s.System.ClientsMax, 1)
}
sessionPresent := s.inheritClientSession(pk, cl)
s.Clients.Add(cl)
err = s.ackConnection(cl, ackCode, sessionPresent)
if err != nil {
return s.onError(cl.Info(), fmt.Errorf("ack connection packet: %w", err))
}
s.Clients.Add(cl) // Overwrite any existing client with the same name.
err = s.writeClient(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Connack,
},
SessionPresent: sessionPresent,
ReturnCode: retcode,
})
if err != nil || retcode != packets.Accepted {
return err
err = s.ResendClientInflight(cl, true)
if err != nil {
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
}
s.ResendClientInflight(cl, true)
if s.Store != nil {
s.Store.WriteClient(persistence.Client{
s.onStorage(cl, s.Store.WriteClient(persistence.Client{
ID: "cl_" + cl.ID,
ClientID: cl.ID,
T: persistence.KClient,
Listener: cl.Listener,
Username: cl.Username,
LWT: persistence.LWT(cl.LWT),
})
}))
}
err = cl.Read(s.processPacket)
if err != nil {
s.closeClient(cl, true)
if s.Events.OnConnect != nil {
s.Events.OnConnect(cl.Info(), events.Packet(pk))
}
s.bytepool.Put(xbr) // Return byte buffers to pools when the client has finished.
s.bytepool.Put(xbw)
if err := cl.Read(s.processPacket); err != nil {
s.sendLWT(cl)
cl.Stop(err)
}
atomic.AddInt64(&s.System.ClientsConnected, -1)
atomic.AddInt64(&s.System.ClientsDisconnected, 1)
err = cl.StopCause() // Determine true cause of stop.
if s.Events.OnDisconnect != nil {
s.Events.OnDisconnect(cl.Info(), err)
}
return err
}
// ackConnection returns a Connack packet to a client.
func (s *Server) ackConnection(cl *clients.Client, ack byte, present bool) error {
return s.writeClient(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Connack,
},
SessionPresent: present,
ReturnCode: ack,
})
}
// inheritClientSession inherits the state of an existing client sharing the same
// connection ID. If cleanSession is true, the state of any previously existing client
// session is abandoned.
func (s *Server) inheritClientSession(pk packets.Packet, cl *clients.Client) bool {
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
existing.Lock()
defer existing.Unlock()
existing.Stop(ErrSessionReestablished) // Issue a stop on the old client.
// Per [MQTT-3.1.2-6]:
// If CleanSession is set to 1, the Client and Server MUST discard any previous Session and start a new one.
// The state associated with a CleanSession MUST NOT be reused in any subsequent session.
if pk.CleanSession || existing.CleanSession {
s.unsubscribeClient(existing)
return false
}
cl.Inflight = existing.Inflight // Take address of existing session.
cl.Subscriptions = existing.Subscriptions
return true
} else {
atomic.AddInt64(&s.System.ClientsTotal, 1)
if atomic.LoadInt64(&s.System.ClientsConnected) > atomic.LoadInt64(&s.System.ClientsMax) {
atomic.AddInt64(&s.System.ClientsMax, 1)
}
return false
}
}
// unsubscribeClient unsubscribes a client from all of their subscriptions.
func (s *Server) unsubscribeClient(cl *clients.Client) {
for k := range cl.Subscriptions {
delete(cl.Subscriptions, k)
if s.Topics.Unsubscribe(k, cl.ID) {
atomic.AddInt64(&s.System.Subscriptions, -1)
}
}
}
// writeClient writes packets to a client connection.
func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
_, err := cl.WritePacket(pk)
if err != nil {
return err
return fmt.Errorf("write: %w", err)
}
return nil
@@ -327,13 +460,14 @@ func (s *Server) processPacket(cl *clients.Client, pk packets.Packet) error {
// establish a new connection on an existing connection. See EstablishConnection
// instead.
func (s *Server) processConnect(cl *clients.Client, pk packets.Packet) error {
s.closeClient(cl, true)
s.sendLWT(cl)
cl.Stop(ErrClientReconnect)
return nil
}
// processDisconnect processes a Disconnect packet.
func (s *Server) processDisconnect(cl *clients.Client, pk packets.Packet) error {
s.closeClient(cl, false)
cl.Stop(ErrClientDisconnect)
return nil
}
@@ -362,14 +496,15 @@ func (s *Server) Publish(topic string, payload []byte, retain bool) error {
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Type: packets.Publish,
Retain: retain,
},
TopicName: topic,
Payload: payload,
}
if retain {
s.retainMessage(pk)
s.retainMessage(&s.inline, pk)
}
// handoff packet to s.inline.pub channel for writing to client buffers
@@ -379,6 +514,16 @@ func (s *Server) Publish(topic string, payload []byte, retain bool) error {
return nil
}
// Info provides pseudo-client information for the inline messages processor.
// It provides a 'client' to which inline retained messages can be assigned.
func (*inlineMessages) Info() events.Client {
return events.Client{
ID: "inline",
Remote: "inline",
Listener: "inline",
}
}
// processPublish processes a Publish packet.
func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
if len(pk.TopicName) >= 4 && pk.TopicName[0:4] == "$SYS" {
@@ -389,8 +534,25 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
return nil
}
// if an OnProcessMessage hook exists, potentially modify the packet.
if s.Events.OnProcessMessage != nil {
pkx, err := s.Events.OnProcessMessage(cl.Info(), events.Packet(pk))
if err == nil {
pk = packets.Packet(pkx) // Only use the new package changes if there's no errors.
} else {
// If the ErrRejectPacket is return, abandon processing the packet.
if err == ErrRejectPacket {
return nil
}
if s.Events.OnError != nil {
s.Events.OnError(cl.Info(), err)
}
}
}
if pk.FixedHeader.Retain {
s.retainMessage(pk)
s.retainMessage(cl, pk)
}
if pk.FixedHeader.Qos > 0 {
@@ -407,12 +569,12 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
// omit errors in case of broken connection / LWT publish. ack send failures
// will be handled by in-flight resending on next reconnect.
s.writeClient(cl, ack)
s.onError(cl.Info(), s.writeClient(cl, ack))
}
// if an OnMessage hook exists, potentially modify the packet.
if s.Events.OnMessage != nil {
if pkx, err := s.Events.OnMessage(events.FromClient(*cl), events.Packet(pk)); err == nil {
if pkx, err := s.Events.OnMessage(cl.Info(), events.Packet(pk)); err == nil {
pk = packets.Packet(pkx)
}
}
@@ -425,21 +587,23 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
// retainMessage adds a message to a topic, and if a persistent store is provided,
// adds the message to the store so it can be reloaded if necessary.
func (s *Server) retainMessage(pk packets.Packet) {
func (s *Server) retainMessage(cl events.Clientlike, pk packets.Packet) {
out := pk.PublishCopy()
q := s.Topics.RetainMessage(out)
atomic.AddInt64(&s.System.Retained, q)
r := s.Topics.RetainMessage(out)
atomic.AddInt64(&s.System.Retained, r)
if s.Store != nil {
if q == 1 {
s.Store.WriteRetained(persistence.Message{
ID: "ret_" + out.TopicName,
id := "ret_" + out.TopicName
if r == 1 {
s.onStorage(cl, s.Store.WriteRetained(persistence.Message{
ID: id,
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader(out.FixedHeader),
TopicName: out.TopicName,
Payload: out.Payload,
})
}))
} else {
s.Store.DeleteRetained("ret_" + out.TopicName)
s.onStorage(cl, s.Store.DeleteRetained(id))
}
}
}
@@ -449,6 +613,14 @@ func (s *Server) retainMessage(pk packets.Packet) {
func (s *Server) publishToSubscribers(pk packets.Packet) {
for id, qos := range s.Topics.Subscribers(pk.TopicName) {
if client, ok := s.Clients.Get(id); ok {
// If the AllowClients value is set, only deliver the packet if the subscribed
// client exists in the AllowClients value. For use with the OnMessage event hook
// in cases where you want to publish messages to clients selectively.
if pk.AllowClients != nil && !utils.InSliceString(pk.AllowClients, id) {
continue
}
out := pk.PublishCopy()
if qos > out.FixedHeader.Qos { // Inherit higher desired qos values.
out.FixedHeader.Qos = qos
@@ -473,18 +645,18 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
}
if s.Store != nil {
s.Store.WriteInflight(persistence.Message{
ID: "if_" + client.ID + "_" + strconv.Itoa(int(out.PacketID)),
T: persistence.KRetained,
s.onStorage(client, s.Store.WriteInflight(persistence.Message{
ID: persistentID(client, out),
T: persistence.KInflight,
FixedHeader: persistence.FixedHeader(out.FixedHeader),
TopicName: out.TopicName,
Payload: out.Payload,
Sent: sent,
})
}))
}
}
s.writeClient(client, out)
s.onError(client.Info(), s.writeClient(client, out))
}
}
}
@@ -496,7 +668,7 @@ func (s *Server) processPuback(cl *clients.Client, pk packets.Packet) error {
atomic.AddInt64(&s.System.Inflight, -1)
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
}
return nil
}
@@ -538,7 +710,7 @@ func (s *Server) processPubrel(cl *clients.Client, pk packets.Packet) error {
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
}
return nil
@@ -551,7 +723,7 @@ func (s *Server) processPubcomp(cl *clients.Client, pk packets.Packet) error {
atomic.AddInt64(&s.System.Inflight, -1)
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
}
return nil
}
@@ -571,13 +743,13 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
retCodes[i] = pk.Qoss[i]
if s.Store != nil {
s.Store.WriteSubscription(persistence.Subscription{
s.onStorage(cl, s.Store.WriteSubscription(persistence.Subscription{
ID: "sub_" + cl.ID + ":" + pk.Topics[i],
T: persistence.KSubscription,
Filter: pk.Topics[i],
Client: cl.ID,
QoS: pk.Qoss[i],
})
}))
}
}
}
@@ -593,10 +765,15 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
return err
}
// Publish out any retained messages matching the subscription filter.
// Publish out any retained messages matching the subscription filter and the user has
// been allowed to subscribe to.
for i := 0; i < len(pk.Topics); i++ {
if retCodes[i] == packets.ErrSubAckNetworkError {
continue
}
for _, pkv := range s.Topics.Messages(pk.Topics[i]) {
s.writeClient(cl, pkv) // omit errors, prefer continuing.
s.onError(cl.Info(), s.writeClient(cl, pkv))
}
}
@@ -626,6 +803,17 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
return nil
}
// atomicItoa reads an *int64 and formats a decimal string.
func atomicItoa(ptr *int64) string {
return strconv.FormatInt(atomic.LoadInt64(ptr), 10)
}
// persistentID return a string combining the client and packet
// identifiers for use with the persistence layer.
func persistentID(client *clients.Client, pk packets.Packet) string {
return "if_" + client.ID + "_" + pk.FormatID()
}
// publishSysTopics publishes the current values to the server $SYS topics.
// Due to the int to string conversions this method is not as cheap as
// some of the others so the publishing interval should be set appropriately.
@@ -637,26 +825,27 @@ func (s *Server) publishSysTopics() {
},
}
s.System.Uptime = time.Now().Unix() - s.System.Started
uptime := time.Now().Unix() - atomic.LoadInt64(&s.System.Started)
atomic.StoreInt64(&s.System.Uptime, uptime)
topics := map[string]string{
"$SYS/broker/version": s.System.Version,
"$SYS/broker/uptime": strconv.Itoa(int(s.System.Uptime)),
"$SYS/broker/timestamp": strconv.Itoa(int(s.System.Started)),
"$SYS/broker/load/bytes/received": strconv.Itoa(int(s.System.BytesRecv)),
"$SYS/broker/load/bytes/sent": strconv.Itoa(int(s.System.BytesSent)),
"$SYS/broker/clients/connected": strconv.Itoa(int(s.System.ClientsConnected)),
"$SYS/broker/clients/disconnected": strconv.Itoa(int(s.System.ClientsDisconnected)),
"$SYS/broker/clients/maximum": strconv.Itoa(int(s.System.ClientsMax)),
"$SYS/broker/clients/total": strconv.Itoa(int(s.System.ClientsTotal)),
"$SYS/broker/connections/total": strconv.Itoa(int(s.System.ConnectionsTotal)),
"$SYS/broker/messages/received": strconv.Itoa(int(s.System.MessagesRecv)),
"$SYS/broker/messages/sent": strconv.Itoa(int(s.System.MessagesSent)),
"$SYS/broker/messages/publish/dropped": strconv.Itoa(int(s.System.PublishDropped)),
"$SYS/broker/messages/publish/received": strconv.Itoa(int(s.System.PublishRecv)),
"$SYS/broker/messages/publish/sent": strconv.Itoa(int(s.System.PublishSent)),
"$SYS/broker/messages/retained/count": strconv.Itoa(int(s.System.Retained)),
"$SYS/broker/messages/inflight": strconv.Itoa(int(s.System.Inflight)),
"$SYS/broker/subscriptions/count": strconv.Itoa(int(s.System.Subscriptions)),
"$SYS/broker/uptime": atomicItoa(&s.System.Uptime),
"$SYS/broker/timestamp": atomicItoa(&s.System.Started),
"$SYS/broker/load/bytes/received": atomicItoa(&s.System.BytesRecv),
"$SYS/broker/load/bytes/sent": atomicItoa(&s.System.BytesSent),
"$SYS/broker/clients/connected": atomicItoa(&s.System.ClientsConnected),
"$SYS/broker/clients/disconnected": atomicItoa(&s.System.ClientsDisconnected),
"$SYS/broker/clients/maximum": atomicItoa(&s.System.ClientsMax),
"$SYS/broker/clients/total": atomicItoa(&s.System.ClientsTotal),
"$SYS/broker/connections/total": atomicItoa(&s.System.ConnectionsTotal),
"$SYS/broker/messages/received": atomicItoa(&s.System.MessagesRecv),
"$SYS/broker/messages/sent": atomicItoa(&s.System.MessagesSent),
"$SYS/broker/messages/publish/dropped": atomicItoa(&s.System.PublishDropped),
"$SYS/broker/messages/publish/received": atomicItoa(&s.System.PublishRecv),
"$SYS/broker/messages/publish/sent": atomicItoa(&s.System.PublishSent),
"$SYS/broker/messages/retained/count": atomicItoa(&s.System.Retained),
"$SYS/broker/messages/inflight": atomicItoa(&s.System.Inflight),
"$SYS/broker/subscriptions/count": atomicItoa(&s.System.Subscriptions),
}
for topic, payload := range topics {
@@ -668,10 +857,10 @@ func (s *Server) publishSysTopics() {
}
if s.Store != nil {
s.Store.WriteServerInfo(persistence.ServerInfo{
s.onStorage(&s.inline, s.Store.WriteServerInfo(persistence.ServerInfo{
Info: *s.System,
ID: persistence.KServerInfo,
})
}))
}
}
@@ -691,7 +880,7 @@ func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
}
if s.Store != nil {
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)))
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, tk.Packet)))
}
continue
@@ -715,15 +904,15 @@ func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
}
if s.Store != nil {
s.Store.WriteInflight(persistence.Message{
ID: "if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)),
T: persistence.KRetained,
s.onStorage(cl, s.Store.WriteInflight(persistence.Message{
ID: persistentID(cl, tk.Packet),
T: persistence.KInflight,
FixedHeader: persistence.FixedHeader(tk.Packet.FixedHeader),
TopicName: tk.Packet.TopicName,
Payload: tk.Packet.Payload,
Sent: tk.Sent,
Resends: tk.Resends,
})
}))
}
}
@@ -746,15 +935,14 @@ func (s *Server) Close() error {
func (s *Server) closeListenerClients(listener string) {
clients := s.Clients.GetByListener(listener)
for _, cl := range clients {
s.closeClient(cl, false) // omit errors
cl.Stop(ErrServerShutdown)
}
}
// closeClient closes a client connection and publishes any LWT messages.
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
if sendLWT && cl.LWT.Topic != "" {
s.processPublish(cl, packets.Packet{
// sendLWT issues an LWT message to a topic when a client disconnects.
func (s *Server) sendLWT(cl *clients.Client) error {
if cl.LWT.Topic != "" {
err := s.processPublish(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Retain: cl.LWT.Retain,
@@ -763,10 +951,11 @@ func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
TopicName: cl.LWT.Topic,
Payload: cl.LWT.Message,
})
if err != nil {
return s.onError(cl.Info(), fmt.Errorf("send lwt: %s %w; %+v", cl.ID, err, cl.LWT))
}
}
cl.Stop()
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
package system
import (
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
func TestInfoAlignment(t *testing.T) {
typ := reflect.TypeOf(Info{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
switch f.Type.Kind() {
case reflect.Int64, reflect.Uint64:
require.Equalf(t, uintptr(0), f.Offset%8,
"%s requires 64-bit alignment for atomic: offset %d",
f.Name, f.Offset)
}
}
}

View File

@@ -3,13 +3,11 @@ language: go
before_install:
- go get github.com/stretchr/testify
env:
GO111MODULE=on
env: GO111MODULE=on
go:
- '1.11.x'
- '1.12.x'
- '1.13.x'
- "1.13.x"
- "1.14.x"
- tip
matrix:

View File

@@ -49,7 +49,7 @@ _For extended queries and support for [Badger](https://github.com/dgraph-io/badg
## Getting Started
```bash
go get -u github.com/asdine/storm
GO111MODULE=on go get -u github.com/asdine/storm/v3
```
## Import Storm

View File

@@ -1,20 +0,0 @@
module github.com/asdine/storm/v3
require (
github.com/DataDog/zstd v1.4.1 // indirect
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/protobuf v1.3.2
github.com/golang/snappy v0.0.1 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.2.2
github.com/vmihailenco/msgpack v4.0.4+incompatible
go.etcd.io/bbolt v1.3.3
golang.org/x/net v0.0.0-20191105084925-a882066a44e0 // indirect
golang.org/x/sys v0.0.0-20191105142833-ac3223d80179 // indirect
google.golang.org/appengine v1.6.5 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
)
go 1.13

View File

@@ -1,40 +0,0 @@
github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM=
github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk=
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65 h1:+rhAzEzT3f4JtomfC371qB+0Ola2caSKcY69NUBZrRQ=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0 h1:QPlSTtPE2k6PZPasQUbzuK3p9JbS+vMXYVto8g/yrsg=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191105142833-ac3223d80179 h1:IqVhUQp5B9ARnZUcfqXy6zP+A+YuPpP7IFo8gFeCOzU=
golang.org/x/sys v0.0.0-20191105142833-ac3223d80179/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -6,9 +6,16 @@
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
---
⚠️ **[The Gorilla WebSocket Package is looking for a new maintainer](https://github.com/gorilla/websocket/issues/370)**
---
### Documentation
* [API Reference](http://godoc.org/github.com/gorilla/websocket)
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
@@ -30,35 +37,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="https://github.com/crossbario/autobahn-testsuite">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.

View File

@@ -48,15 +48,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
}
// A Dialer contains options for connecting to WebSocket server.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used.
// NetDialContext is nil, NetDial is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialContext is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
@@ -65,6 +73,8 @@ type Dialer struct {
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
@@ -176,7 +186,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
req := &http.Request{
Method: "GET",
Method: http.MethodGet,
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
@@ -237,13 +247,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
switch u.Scheme {
case "http":
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
} else if d.NetDial != nil {
netDial = d.NetDial
} else {
case "https":
if d.NetDialTLSContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
default:
return nil, nil, errMalformedURL
}
if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
@@ -304,7 +333,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()
if u.Scheme == "https" {
if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
@@ -312,11 +343,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
var err error
if trace != nil {
err = doHandshakeWithTrace(trace, tlsConn, cfg)
} else {
err = doHandshake(tlsConn, cfg)
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(ctx, tlsConn, cfg)
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
if err != nil {
@@ -348,8 +380,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
@@ -382,14 +414,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil
}
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
return cfg.Clone()
}

View File

@@ -1,16 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

View File

@@ -1,38 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
import "crypto/tls"
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@@ -13,6 +13,7 @@ import (
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
@@ -244,8 +245,8 @@ type Conn struct {
subprotocol string
// Write fields
mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer.
mu chan struct{} // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool
writeBufSize int
writeDeadline time.Time
@@ -302,8 +303,8 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
writeBuf = make([]byte, writeBufferSize)
}
mu := make(chan bool, 1)
mu <- true
mu := make(chan struct{}, 1)
mu <- struct{}{}
c := &Conn{
isServer: isServer,
br: br,
@@ -377,7 +378,7 @@ func (c *Conn) read(n int) ([]byte, error) {
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
<-c.mu
defer func() { c.mu <- true }()
defer func() { c.mu <- struct{}{} }()
c.writeErrMu.Lock()
err := c.writeErr
@@ -401,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return nil
}
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@@ -429,7 +436,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
maskBytes(key, 0, buf[6:])
}
d := time.Hour * 1000
d := 1000 * time.Hour
if !deadline.IsZero() {
d = deadline.Sub(time.Now())
if d < 0 {
@@ -444,7 +451,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
case <-timer.C:
return errWriteTimeout
}
defer func() { c.mu <- true }()
defer func() { c.mu <- struct{}{} }()
c.writeErrMu.Lock()
err := c.writeErr
@@ -794,47 +801,69 @@ func (c *Conn) advanceFrame() (int, error) {
}
// 2. Read and parse first two bytes of frame header.
// To aid debugging, collect and report all errors in the first two bytes
// of the header.
var errors []string
p, err := c.read(2)
if err != nil {
return noFrame, err
}
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
final := p[0]&finalBit != 0
rsv1 := p[0]&rsv1Bit != 0
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true
p[0] &^= rsv1Bit
if rsv1 {
if c.newDecompressionReader != nil {
c.readDecompress = true
} else {
errors = append(errors, "RSV1 set")
}
}
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
if rsv2 {
errors = append(errors, "RSV2 set")
}
if rsv3 {
errors = append(errors, "RSV3 set")
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125")
errors = append(errors, "len > 125 for control")
}
if !final {
return noFrame, c.handleProtocolError("control frame not final")
errors = append(errors, "FIN not set on control")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame")
errors = append(errors, "data before FIN")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame")
errors = append(errors, "continuation after FIN")
}
c.readFinal = final
default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
}
if mask != c.isServer {
errors = append(errors, "bad MASK")
}
if len(errors) > 0 {
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
}
// 3. Read and parse frame length as per
@@ -872,10 +901,6 @@ func (c *Conn) advanceFrame() (int, error) {
// 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
@@ -935,7 +960,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
@@ -952,7 +977,11 @@ func (c *Conn) advanceFrame() (int, error) {
}
func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
data := FormatCloseMessage(CloseProtocolError, message)
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}

View File

@@ -1,15 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "net"
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}

View File

@@ -1,18 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
func (c *Conn) writeBufs(bufs ...[]byte) error {
for _, buf := range bufs {
if len(buf) > 0 {
if _, err := c.conn.Write(buf); err != nil {
return err
}
}
}
return nil
}

View File

@@ -187,9 +187,9 @@
// than the largest message do not provide any benefit.
//
// Depending on the distribution of message sizes, setting the buffer size to
// to a value less than the maximum expected message size can greatly reduce
// memory use with a small impact on performance. Here's an example: If 99% of
// the messages are smaller than 256 bytes and the maximum message size is 512
// a value less than the maximum expected message size can greatly reduce memory
// use with a small impact on performance. Here's an example: If 99% of the
// messages are smaller than 256 bytes and the maximum message size is 512
// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls
// than a buffer size of 512 bytes. The memory savings is 50%.
//

View File

@@ -1,3 +0,0 @@
module github.com/gorilla/websocket
go 1.12

View File

@@ -1,2 +0,0 @@
github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=

View File

@@ -2,6 +2,7 @@
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build !appengine
// +build !appengine
package websocket

View File

@@ -2,6 +2,7 @@
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build appengine
// +build appengine
package websocket

View File

@@ -73,8 +73,8 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
// Prepare a frame using a 'fake' connection.
// TODO: Refactor code in conn.go to allow more direct construction of
// the frame.
mu := make(chan bool, 1)
mu <- true
mu := make(chan struct{}, 1)
mu <- struct{}{}
var nc prepareConn
c := &Conn{
conn: &nc,

View File

@@ -48,7 +48,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
}
connectReq := &http.Request{
Method: "CONNECT",
Method: http.MethodConnect,
URL: &url.URL{Opaque: addr},
Host: addr,
Header: connectHeader,

View File

@@ -23,6 +23,8 @@ func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
//
// It is safe to call Upgrader's methods concurrently.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
@@ -115,8 +117,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-WebSocket-Protocol).
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
@@ -131,7 +133,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
}
if r.Method != "GET" {
if r.Method != http.MethodGet {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
}

21
vendor/github.com/gorilla/websocket/tls_handshake.go generated vendored Normal file
View File

@@ -0,0 +1,21 @@
//go:build go1.17
// +build go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,21 @@
//go:build !go1.17
// +build !go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@@ -1,19 +0,0 @@
// +build go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
if trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(tlsConn, cfg)
if trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
return err
}

View File

@@ -1,12 +0,0 @@
// +build !go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
return doHandshake(tlsConn, cfg)
}

View File

@@ -1,3 +0,0 @@
guard 'gotest' do
watch(%r{\.go$})
end

View File

@@ -2,7 +2,7 @@
I am a copier, I copy everything from one to another
[![wercker status](https://app.wercker.com/status/9d44ad2d4e6253929c8fb71359effc0b/s/master "wercker status")](https://app.wercker.com/project/byKey/9d44ad2d4e6253929c8fb71359effc0b)
[![test status](https://github.com/jinzhu/copier/workflows/tests/badge.svg?branch=master "test status")](https://github.com/jinzhu/copier/actions)
## Features
@@ -11,6 +11,10 @@
* Copy from field to method with same name
* Copy from slice to slice
* Copy from struct to slice
* Copy from map to map
* Enforce copying a field with a tag
* Ignore a field with a tag
* Deep Copy
## Usage
@@ -23,32 +27,45 @@ import (
)
type User struct {
Name string
Role string
Age int32
Name string
Role string
Age int32
EmployeCode int64 `copier:"EmployeNum"` // specify field name
// Explicitly ignored in the destination struct.
Salary int
}
func (user *User) DoubleAge() int32 {
return 2 * user.Age
}
// Tags in the destination Struct provide instructions to copier.Copy to ignore
// or enforce copying and to panic or return an error if a field was not copied.
type Employee struct {
Name string
Age int32
// Tell copier.Copy to panic if this field is not copied.
Name string `copier:"must"`
// Tell copier.Copy to return an error if this field is not copied.
Age int32 `copier:"must,nopanic"`
// Tell copier.Copy to explicitly ignore copying this field.
Salary int `copier:"-"`
DoubleAge int32
EmployeId int64
SuperRule string
EmployeId int64 `copier:"EmployeNum"` // specify field name
SuperRole string
}
func (employee *Employee) Role(role string) {
employee.SuperRule = "Super " + role
employee.SuperRole = "Super " + role
}
func main() {
var (
user = User{Name: "Jinzhu", Age: 18, Role: "Admin"}
users = []User{{Name: "Jinzhu", Age: 18, Role: "Admin"}, {Name: "jinzhu 2", Age: 30, Role: "Dev"}}
employee = Employee{}
user = User{Name: "Jinzhu", Age: 18, Role: "Admin", Salary: 200000}
users = []User{{Name: "Jinzhu", Age: 18, Role: "Admin", Salary: 100000}, {Name: "jinzhu 2", Age: 30, Role: "Dev", Salary: 60000}}
employee = Employee{Salary: 150000}
employees = []Employee{}
)
@@ -58,9 +75,10 @@ func main() {
// Employee{
// Name: "Jinzhu", // Copy from field
// Age: 18, // Copy from field
// Salary:150000, // Copying explicitly ignored
// DoubleAge: 36, // Copy from method
// EmployeeId: 0, // Ignored
// SuperRule: "Super Admin", // Copy to method
// SuperRole: "Super Admin", // Copy to method
// }
// Copy struct to slice
@@ -68,7 +86,7 @@ func main() {
fmt.Printf("%#v \n", employees)
// []Employee{
// {Name: "Jinzhu", Age: 18, DoubleAge: 36, EmployeId: 0, SuperRule: "Super Admin"}
// {Name: "Jinzhu", Age: 18, Salary:0, DoubleAge: 36, EmployeId: 0, SuperRole: "Super Admin"}
// }
// Copy slice to slice
@@ -77,12 +95,26 @@ func main() {
fmt.Printf("%#v \n", employees)
// []Employee{
// {Name: "Jinzhu", Age: 18, DoubleAge: 36, EmployeId: 0, SuperRule: "Super Admin"},
// {Name: "jinzhu 2", Age: 30, DoubleAge: 60, EmployeId: 0, SuperRule: "Super Dev"},
// {Name: "Jinzhu", Age: 18, Salary:0, DoubleAge: 36, EmployeId: 0, SuperRole: "Super Admin"},
// {Name: "jinzhu 2", Age: 30, Salary:0, DoubleAge: 60, EmployeId: 0, SuperRole: "Super Dev"},
// }
// Copy map to map
map1 := map[int]int{3: 6, 4: 8}
map2 := map[int32]int8{}
copier.Copy(&map2, map1)
fmt.Printf("%#v \n", map2)
// map[int32]int8{3:6, 4:8}
}
```
### Copy with Option
```go
copier.CopyWithOption(&to, &from, copier.Option{IgnoreEmpty: true, DeepCopy: true})
```
## Contributing
You can help to make the project better, check out [http://gorm.io/contribute.html](http://gorm.io/contribute.html) for things you can do.

View File

@@ -2,43 +2,206 @@ package copier
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strings"
"unicode"
)
// These flags define options for tag handling
const (
// Denotes that a destination field must be copied to. If copying fails then a panic will ensue.
tagMust uint8 = 1 << iota
// Denotes that the program should not panic when the must flag is on and
// value is not copied. The program will return an error instead.
tagNoPanic
// Ignore a destination field from being copied to.
tagIgnore
// Denotes that the value as been copied
hasCopied
// Some default converter types for a nicer syntax
String string = ""
Bool bool = false
Int int = 0
Float32 float32 = 0
Float64 float64 = 0
)
// Option sets copy options
type Option struct {
// setting this value to true will ignore copying zero values of all the fields, including bools, as well as a
// struct having all it's fields set to their zero values respectively (see IsZero() in reflect/value.go)
IgnoreEmpty bool
DeepCopy bool
Converters []TypeConverter
}
type TypeConverter struct {
SrcType interface{}
DstType interface{}
Fn func(src interface{}) (interface{}, error)
}
type converterPair struct {
SrcType reflect.Type
DstType reflect.Type
}
// Tag Flags
type flags struct {
BitFlags map[string]uint8
SrcNames tagNameMapping
DestNames tagNameMapping
}
// Field Tag name mapping
type tagNameMapping struct {
FieldNameToTag map[string]string
TagToFieldName map[string]string
}
// Copy copy things
func Copy(toValue interface{}, fromValue interface{}) (err error) {
return copier(toValue, fromValue, Option{})
}
// CopyWithOption copy with option
func CopyWithOption(toValue interface{}, fromValue interface{}, opt Option) (err error) {
return copier(toValue, fromValue, opt)
}
func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) {
var (
isSlice bool
amount = 1
from = indirect(reflect.ValueOf(fromValue))
to = indirect(reflect.ValueOf(toValue))
isSlice bool
amount = 1
from = indirect(reflect.ValueOf(fromValue))
to = indirect(reflect.ValueOf(toValue))
converters map[converterPair]TypeConverter
)
// save convertes into map for faster lookup
for i := range opt.Converters {
if converters == nil {
converters = make(map[converterPair]TypeConverter)
}
pair := converterPair{
SrcType: reflect.TypeOf(opt.Converters[i].SrcType),
DstType: reflect.TypeOf(opt.Converters[i].DstType),
}
converters[pair] = opt.Converters[i]
}
if !to.CanAddr() {
return errors.New("copy to value is unaddressable")
return ErrInvalidCopyDestination
}
// Return is from value is invalid
if !from.IsValid() {
return ErrInvalidCopyFrom
}
fromType, isPtrFrom := indirectType(from.Type())
toType, _ := indirectType(to.Type())
if fromType.Kind() == reflect.Interface {
fromType = reflect.TypeOf(from.Interface())
}
if toType.Kind() == reflect.Interface {
toType, _ = indirectType(reflect.TypeOf(to.Interface()))
oldTo := to
to = reflect.New(reflect.TypeOf(to.Interface())).Elem()
defer func() {
oldTo.Set(to)
}()
}
// Just set it if possible to assign for normal types
if from.Kind() != reflect.Slice && from.Kind() != reflect.Struct && from.Kind() != reflect.Map && (from.Type().AssignableTo(to.Type()) || from.Type().ConvertibleTo(to.Type())) {
if !isPtrFrom || !opt.DeepCopy {
to.Set(from.Convert(to.Type()))
} else {
fromCopy := reflect.New(from.Type())
fromCopy.Set(from.Elem())
to.Set(fromCopy.Convert(to.Type()))
}
return
}
fromType := indirectType(from.Type())
toType := indirectType(to.Type())
if from.Kind() != reflect.Slice && fromType.Kind() == reflect.Map && toType.Kind() == reflect.Map {
if !fromType.Key().ConvertibleTo(toType.Key()) {
return ErrMapKeyNotMatch
}
// Just set it if possible to assign
// And need to do copy anyway if the type is struct
if fromType.Kind() != reflect.Struct && from.Type().AssignableTo(to.Type()) {
to.Set(from)
if to.IsNil() {
to.Set(reflect.MakeMapWithSize(toType, from.Len()))
}
for _, k := range from.MapKeys() {
toKey := indirect(reflect.New(toType.Key()))
if !set(toKey, k, opt.DeepCopy, converters) {
return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key())
}
elemType := toType.Elem()
if elemType.Kind() != reflect.Slice {
elemType, _ = indirectType(elemType)
}
toValue := indirect(reflect.New(elemType))
if !set(toValue, from.MapIndex(k), opt.DeepCopy, converters) {
if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil {
return err
}
}
for {
if elemType == toType.Elem() {
to.SetMapIndex(toKey, toValue)
break
}
elemType = reflect.PtrTo(elemType)
toValue = toValue.Addr()
}
}
return
}
if from.Kind() == reflect.Slice && to.Kind() == reflect.Slice && fromType.ConvertibleTo(toType) {
if to.IsNil() {
slice := reflect.MakeSlice(reflect.SliceOf(to.Type().Elem()), from.Len(), from.Cap())
to.Set(slice)
}
for i := 0; i < from.Len(); i++ {
if to.Len() < i+1 {
to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem()))
}
if !set(to.Index(i), from.Index(i), opt.DeepCopy, converters) {
// ignore error while copy slice element
err = copier(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt)
if err != nil {
continue
}
}
}
return
}
if fromType.Kind() != reflect.Struct || toType.Kind() != reflect.Struct {
// skip not supported type
return
}
if to.Kind() == reflect.Slice {
if from.Kind() == reflect.Slice || to.Kind() == reflect.Slice {
isSlice = true
if from.Kind() == reflect.Slice {
amount = from.Len()
@@ -62,31 +225,86 @@ func Copy(toValue interface{}, fromValue interface{}) (err error) {
dest = indirect(to)
}
destKind := dest.Kind()
initDest := false
if destKind == reflect.Interface {
initDest = true
dest = indirect(reflect.New(toType))
}
// Get tag options
flgs, err := getFlags(dest, source, toType, fromType)
if err != nil {
return err
}
// check source
if source.IsValid() {
copyUnexportedStructFields(dest, source)
// Copy from source field to dest field or method
fromTypeFields := deepFields(fromType)
//fmt.Printf("%#v", fromTypeFields)
// Copy from field to field or method
for _, field := range fromTypeFields {
name := field.Name
if fromField := source.FieldByName(name); fromField.IsValid() {
// has field
if toField := dest.FieldByName(name); toField.IsValid() {
// Get bit flags for field
fieldFlags, _ := flgs.BitFlags[name]
// Check if we should ignore copying
if (fieldFlags & tagIgnore) != 0 {
continue
}
srcFieldName, destFieldName := getFieldName(name, flgs)
if fromField := source.FieldByName(srcFieldName); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) {
// process for nested anonymous field
destFieldNotSet := false
if f, ok := dest.Type().FieldByName(destFieldName); ok {
for idx := range f.Index {
destField := dest.FieldByIndex(f.Index[:idx+1])
if destField.Kind() != reflect.Ptr {
continue
}
if !destField.IsNil() {
continue
}
if !destField.CanSet() {
destFieldNotSet = true
break
}
// destField is a nil pointer that can be set
newValue := reflect.New(destField.Type().Elem())
destField.Set(newValue)
}
}
if destFieldNotSet {
break
}
toField := dest.FieldByName(destFieldName)
if toField.IsValid() {
if toField.CanSet() {
if !set(toField, fromField) {
if err := Copy(toField.Addr().Interface(), fromField.Interface()); err != nil {
if !set(toField, fromField, opt.DeepCopy, converters) {
if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil {
return err
}
}
if fieldFlags != 0 {
// Note that a copy was made
flgs.BitFlags[name] = fieldFlags | hasCopied
}
}
} else {
// try to set to method
var toMethod reflect.Value
if dest.CanAddr() {
toMethod = dest.Addr().MethodByName(name)
toMethod = dest.Addr().MethodByName(destFieldName)
} else {
toMethod = dest.MethodByName(name)
toMethod = dest.MethodByName(destFieldName)
}
if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && fromField.Type().AssignableTo(toMethod.Type().In(0)) {
@@ -96,53 +314,113 @@ func Copy(toValue interface{}, fromValue interface{}) (err error) {
}
}
// Copy from method to field
// Copy from from method to dest field
for _, field := range deepFields(toType) {
name := field.Name
srcFieldName, destFieldName := getFieldName(name, flgs)
var fromMethod reflect.Value
if source.CanAddr() {
fromMethod = source.Addr().MethodByName(name)
fromMethod = source.Addr().MethodByName(srcFieldName)
} else {
fromMethod = source.MethodByName(name)
fromMethod = source.MethodByName(srcFieldName)
}
if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 {
if toField := dest.FieldByName(name); toField.IsValid() && toField.CanSet() {
if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 && !shouldIgnore(fromMethod, opt.IgnoreEmpty) {
if toField := dest.FieldByName(destFieldName); toField.IsValid() && toField.CanSet() {
values := fromMethod.Call([]reflect.Value{})
if len(values) >= 1 {
set(toField, values[0])
set(toField, values[0], opt.DeepCopy, converters)
}
}
}
}
}
if isSlice {
if isSlice && to.Kind() == reflect.Slice {
if dest.Addr().Type().AssignableTo(to.Type().Elem()) {
to.Set(reflect.Append(to, dest.Addr()))
if to.Len() < i+1 {
to.Set(reflect.Append(to, dest.Addr()))
} else {
if !set(to.Index(i), dest.Addr(), opt.DeepCopy, converters) {
// ignore error while copy slice element
err = copier(to.Index(i).Addr().Interface(), dest.Addr().Interface(), opt)
if err != nil {
continue
}
}
}
} else if dest.Type().AssignableTo(to.Type().Elem()) {
to.Set(reflect.Append(to, dest))
if to.Len() < i+1 {
to.Set(reflect.Append(to, dest))
} else {
if !set(to.Index(i), dest, opt.DeepCopy, converters) {
// ignore error while copy slice element
err = copier(to.Index(i).Addr().Interface(), dest.Interface(), opt)
if err != nil {
continue
}
}
}
}
} else if initDest {
to.Set(dest)
}
err = checkBitFlags(flgs.BitFlags)
}
return
}
func deepFields(reflectType reflect.Type) []reflect.StructField {
var fields []reflect.StructField
if reflectType = indirectType(reflectType); reflectType.Kind() == reflect.Struct {
for i := 0; i < reflectType.NumField(); i++ {
v := reflectType.Field(i)
if v.Anonymous {
fields = append(fields, deepFields(v.Type)...)
} else {
fields = append(fields, v)
}
}
func copyUnexportedStructFields(to, from reflect.Value) {
if from.Kind() != reflect.Struct || to.Kind() != reflect.Struct || !from.Type().AssignableTo(to.Type()) {
return
}
return fields
// create a shallow copy of 'to' to get all fields
tmp := indirect(reflect.New(to.Type()))
tmp.Set(from)
// revert exported fields
for i := 0; i < to.NumField(); i++ {
if tmp.Field(i).CanSet() {
tmp.Field(i).Set(to.Field(i))
}
}
to.Set(tmp)
}
func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool {
if !ignoreEmpty {
return false
}
return v.IsZero()
}
func deepFields(reflectType reflect.Type) []reflect.StructField {
if reflectType, _ = indirectType(reflectType); reflectType.Kind() == reflect.Struct {
fields := make([]reflect.StructField, 0, reflectType.NumField())
for i := 0; i < reflectType.NumField(); i++ {
v := reflectType.Field(i)
// PkgPath is the package path that qualifies a lower case (unexported)
// field name. It is empty for upper case (exported) field names.
// See https://golang.org/ref/spec#Uniqueness_of_identifiers
if v.PkgPath == "" {
fields = append(fields, v)
if v.Anonymous {
// also consider fields of anonymous fields as fields of the root
fields = append(fields, deepFields(v.Type)...)
}
}
}
return fields
}
return nil
}
func indirect(reflectValue reflect.Value) reflect.Value {
@@ -152,38 +430,268 @@ func indirect(reflectValue reflect.Value) reflect.Value {
return reflectValue
}
func indirectType(reflectType reflect.Type) reflect.Type {
func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) {
for reflectType.Kind() == reflect.Ptr || reflectType.Kind() == reflect.Slice {
reflectType = reflectType.Elem()
isPtr = true
}
return reflectType
return reflectType, isPtr
}
func set(to, from reflect.Value) bool {
func set(to, from reflect.Value, deepCopy bool, converters map[converterPair]TypeConverter) bool {
if from.IsValid() {
if ok, err := lookupAndCopyWithConverter(to, from, converters); err != nil {
return false
} else if ok {
return true
}
if to.Kind() == reflect.Ptr {
//set `to` to nil if from is nil
// set `to` to nil if from is nil
if from.Kind() == reflect.Ptr && from.IsNil() {
to.Set(reflect.Zero(to.Type()))
return true
} else if to.IsNil() {
// `from` -> `to`
// sql.NullString -> *string
if fromValuer, ok := driverValuer(from); ok {
v, err := fromValuer.Value()
if err != nil {
return false
}
// if `from` is not valid do nothing with `to`
if v == nil {
return true
}
}
// allocate new `to` variable with default value (eg. *string -> new(string))
to.Set(reflect.New(to.Type().Elem()))
}
// depointer `to`
to = to.Elem()
}
if deepCopy {
toKind := to.Kind()
if toKind == reflect.Interface && to.IsNil() {
if reflect.TypeOf(from.Interface()) != nil {
to.Set(reflect.New(reflect.TypeOf(from.Interface())).Elem())
toKind = reflect.TypeOf(to.Interface()).Kind()
}
}
if from.Kind() == reflect.Ptr && from.IsNil() {
return true
}
if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice {
return false
}
}
if from.Type().ConvertibleTo(to.Type()) {
to.Set(from.Convert(to.Type()))
} else if scanner, ok := to.Addr().Interface().(sql.Scanner); ok {
err := scanner.Scan(from.Interface())
} else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok {
// `from` -> `to`
// *string -> sql.NullString
if from.Kind() == reflect.Ptr {
// if `from` is nil do nothing with `to`
if from.IsNil() {
return true
}
// depointer `from`
from = indirect(from)
}
// `from` -> `to`
// string -> sql.NullString
// set `to` by invoking method Scan(`from`)
err := toScanner.Scan(from.Interface())
if err != nil {
return false
}
} else if fromValuer, ok := driverValuer(from); ok {
// `from` -> `to`
// sql.NullString -> string
v, err := fromValuer.Value()
if err != nil {
return false
}
// if `from` is not valid do nothing with `to`
if v == nil {
return true
}
rv := reflect.ValueOf(v)
if rv.Type().AssignableTo(to.Type()) {
to.Set(rv)
}
} else if from.Kind() == reflect.Ptr {
return set(to, from.Elem())
return set(to, from.Elem(), deepCopy, converters)
} else {
return false
}
}
return true
}
// lookupAndCopyWithConverter looks up the type pair, on success the TypeConverter Fn func is called to copy src to dst field.
func lookupAndCopyWithConverter(to, from reflect.Value, converters map[converterPair]TypeConverter) (copied bool, err error) {
pair := converterPair{
SrcType: from.Type(),
DstType: to.Type(),
}
if cnv, ok := converters[pair]; ok {
result, err := cnv.Fn(from.Interface())
if err != nil {
return false, err
}
if result != nil {
to.Set(reflect.ValueOf(result))
} else {
// in case we've got a nil value to copy
to.Set(reflect.Zero(to.Type()))
}
return true, nil
}
return false, nil
}
// parseTags Parses struct tags and returns uint8 bit flags.
func parseTags(tag string) (flg uint8, name string, err error) {
for _, t := range strings.Split(tag, ",") {
switch t {
case "-":
flg = tagIgnore
return
case "must":
flg = flg | tagMust
case "nopanic":
flg = flg | tagNoPanic
default:
if unicode.IsUpper([]rune(t)[0]) {
name = strings.TrimSpace(t)
} else {
err = errors.New("copier field name tag must be start upper case")
}
}
}
return
}
// getTagFlags Parses struct tags for bit flags, field name.
func getFlags(dest, src reflect.Value, toType, fromType reflect.Type) (flags, error) {
flgs := flags{
BitFlags: map[string]uint8{},
SrcNames: tagNameMapping{
FieldNameToTag: map[string]string{},
TagToFieldName: map[string]string{},
},
DestNames: tagNameMapping{
FieldNameToTag: map[string]string{},
TagToFieldName: map[string]string{},
},
}
var toTypeFields, fromTypeFields []reflect.StructField
if dest.IsValid() {
toTypeFields = deepFields(toType)
}
if src.IsValid() {
fromTypeFields = deepFields(fromType)
}
// Get a list dest of tags
for _, field := range toTypeFields {
tags := field.Tag.Get("copier")
if tags != "" {
var name string
var err error
if flgs.BitFlags[field.Name], name, err = parseTags(tags); err != nil {
return flags{}, err
} else if name != "" {
flgs.DestNames.FieldNameToTag[field.Name] = name
flgs.DestNames.TagToFieldName[name] = field.Name
}
}
}
// Get a list source of tags
for _, field := range fromTypeFields {
tags := field.Tag.Get("copier")
if tags != "" {
var name string
var err error
if _, name, err = parseTags(tags); err != nil {
return flags{}, err
} else if name != "" {
flgs.SrcNames.FieldNameToTag[field.Name] = name
flgs.SrcNames.TagToFieldName[name] = field.Name
}
}
}
return flgs, nil
}
// checkBitFlags Checks flags for error or panic conditions.
func checkBitFlags(flagsList map[string]uint8) (err error) {
// Check flag conditions were met
for name, flgs := range flagsList {
if flgs&hasCopied == 0 {
switch {
case flgs&tagMust != 0 && flgs&tagNoPanic != 0:
err = fmt.Errorf("field %s has must tag but was not copied", name)
return
case flgs&(tagMust) != 0:
panic(fmt.Sprintf("Field %s has must tag but was not copied", name))
}
}
}
return
}
func getFieldName(fieldName string, flgs flags) (srcFieldName string, destFieldName string) {
// get dest field name
if srcTagName, ok := flgs.SrcNames.FieldNameToTag[fieldName]; ok {
destFieldName = srcTagName
if destTagName, ok := flgs.DestNames.TagToFieldName[srcTagName]; ok {
destFieldName = destTagName
}
} else {
if destTagName, ok := flgs.DestNames.TagToFieldName[fieldName]; ok {
destFieldName = destTagName
}
}
if destFieldName == "" {
destFieldName = fieldName
}
// get source field name
if destTagName, ok := flgs.DestNames.FieldNameToTag[fieldName]; ok {
srcFieldName = destTagName
if srcField, ok := flgs.SrcNames.TagToFieldName[destTagName]; ok {
srcFieldName = srcField
}
} else {
if srcField, ok := flgs.SrcNames.TagToFieldName[fieldName]; ok {
srcFieldName = srcField
}
}
if srcFieldName == "" {
srcFieldName = fieldName
}
return
}
func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) {
if !v.CanAddr() {
i, ok = v.Interface().(driver.Valuer)
return
}
i, ok = v.Addr().Interface().(driver.Valuer)
return
}

10
vendor/github.com/jinzhu/copier/errors.go generated vendored Normal file
View File

@@ -0,0 +1,10 @@
package copier
import "errors"
var (
ErrInvalidCopyDestination = errors.New("copy destination is invalid")
ErrInvalidCopyFrom = errors.New("copy from is invalid")
ErrMapKeyNotMatch = errors.New("map's key type doesn't match")
ErrNotSupported = errors.New("not supported")
)

View File

@@ -1,23 +0,0 @@
box: golang
build:
steps:
- setup-go-workspace
# Gets the dependencies
- script:
name: go get
code: |
go get
# Build the project
- script:
name: go build
code: |
go build ./...
# Test the project
- script:
name: go test
code: |
go test ./...

View File

@@ -1,6 +1,12 @@
Changes
=======
---
16:05:02
Thursday, July 2, 2020
Change license from the WTFPL to the Unlicense due to pkg.go.dev restriction.
---
15:39:40
Wednesday, April 17, 2019

View File

@@ -1,13 +1,24 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
This is free and unencumbered software released into the public domain.
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
0. You just DO WHAT THE FUCK YOU WANT TO.
For more information, please refer to <http://unlicense.org/>

View File

@@ -1,8 +1,8 @@
Aurora
======
[![GoDoc](https://godoc.org/github.com/logrusorgru/aurora?status.svg)](https://godoc.org/github.com/logrusorgru/aurora)
[![WTFPL License](https://img.shields.io/badge/license-wtfpl-blue.svg)](http://www.wtfpl.net/about/)
[![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white)](https://pkg.go.dev/github.com/logrusorgru/aurora?tab=doc)
[![Unlicense](https://img.shields.io/badge/license-unlicense-blue.svg)](http://unlicense.org/)
[![Build Status](https://travis-ci.org/logrusorgru/aurora.svg)](https://travis-ci.org/logrusorgru/aurora)
[![Coverage Status](https://coveralls.io/repos/logrusorgru/aurora/badge.svg?branch=master)](https://coveralls.io/r/logrusorgru/aurora?branch=master)
[![GoReportCard](https://goreportcard.com/badge/logrusorgru/aurora)](https://goreportcard.com/report/logrusorgru/aurora)
@@ -236,7 +236,7 @@ Methods `Index` and `BgIndex` implements 8-bit colors.
+ red
+ green
+ yellow (brown)
+ blue
+ blue
+ magenta
+ cyan
+ white
@@ -295,7 +295,7 @@ The obvious workaround is `Red(fmt.Sprintf("%T", some))`
The Aurora provides ANSI colors only, so there is no support for Windows. That said, there are workarounds available.
Check out these comments to learn more:
- [Using go-colrable](https://github.com/logrusorgru/aurora/issues/2#issuecomment-299014211).
- [Using go-colorable](https://github.com/logrusorgru/aurora/issues/2#issuecomment-299014211).
- [Using registry for Windows 10](https://github.com/logrusorgru/aurora/issues/10#issue-476361247).
### TTY
@@ -306,10 +306,9 @@ on colors for a terminal only, and turn them off for a file.
### Licensing
Copyright &copy; 2016-2019 The Aurora Authors. This work is free.
Copyright &copy; 2016-2020 The Aurora Authors. This work is free.
It comes without any warranty, to the extent permitted by applicable
law. You can redistribute it and/or modify it under the terms of the
Do What The Fuck You Want To Public License, Version 2, as published
by Sam Hocevar. See the LICENSE file for more details.
the Unlicense. See the LICENSE file for more details.

View File

@@ -1,26 +1,36 @@
//
// Copyright (c) 2016-2019 The Aurora Authors. All rights reserved.
// Copyright (c) 2016-2020 The Aurora Authors. All rights reserved.
// This program is free software. It comes without any warranty,
// to the extent permitted by applicable law. You can redistribute
// it and/or modify it under the terms of the Do What The Fuck You
// Want To Public License, Version 2, as published by Sam Hocevar.
// See LICENSE file for more details or see below.
// it and/or modify it under the terms of the Unlicense. See LICENSE
// file for more details or see below.
//
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// Version 2, December 2004
// This is free and unencumbered software released into the public domain.
//
// Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
// Anyone is free to copy, modify, publish, use, compile, sell, or
// distribute this software, either in source code form or as a compiled
// binary, for any purpose, commercial or non-commercial, and by any
// means.
//
// Everyone is permitted to copy and distribute verbatim or modified
// copies of this license document, and changing it is allowed as long
// as the name is changed.
// In jurisdictions that recognize copyright laws, the author or authors
// of this software dedicate any and all copyright interest in the
// software to the public domain. We make this dedication for the benefit
// of the public at large and to the detriment of our heirs and
// successors. We intend this dedication to be an overt act of
// relinquishment in perpetuity of all present and future rights to this
// software under copyright law.
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
//
// 0. You just DO WHAT THE FUCK YOU WANT TO.
// For more information, please refer to <http://unlicense.org/>
//
// Package aurora implements ANSI-colors

View File

@@ -1,26 +1,36 @@
//
// Copyright (c) 2016-2019 The Aurora Authors. All rights reserved.
// Copyright (c) 2016-2020 The Aurora Authors. All rights reserved.
// This program is free software. It comes without any warranty,
// to the extent permitted by applicable law. You can redistribute
// it and/or modify it under the terms of the Do What The Fuck You
// Want To Public License, Version 2, as published by Sam Hocevar.
// See LICENSE file for more details or see below.
// it and/or modify it under the terms of the Unlicense. See LICENSE
// file for more details or see below.
//
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// Version 2, December 2004
// This is free and unencumbered software released into the public domain.
//
// Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
// Anyone is free to copy, modify, publish, use, compile, sell, or
// distribute this software, either in source code form or as a compiled
// binary, for any purpose, commercial or non-commercial, and by any
// means.
//
// Everyone is permitted to copy and distribute verbatim or modified
// copies of this license document, and changing it is allowed as long
// as the name is changed.
// In jurisdictions that recognize copyright laws, the author or authors
// of this software dedicate any and all copyright interest in the
// software to the public domain. We make this dedication for the benefit
// of the public at large and to the detriment of our heirs and
// successors. We intend this dedication to be an overt act of
// relinquishment in perpetuity of all present and future rights to this
// software under copyright law.
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
//
// 0. You just DO WHAT THE FUCK YOU WANT TO.
// For more information, please refer to <http://unlicense.org/>
//
package aurora

View File

@@ -1,26 +1,36 @@
//
// Copyright (c) 2016-2019 The Aurora Authors. All rights reserved.
// Copyright (c) 2016-2020 The Aurora Authors. All rights reserved.
// This program is free software. It comes without any warranty,
// to the extent permitted by applicable law. You can redistribute
// it and/or modify it under the terms of the Do What The Fuck You
// Want To Public License, Version 2, as published by Sam Hocevar.
// See LICENSE file for more details or see below.
// it and/or modify it under the terms of the Unlicense. See LICENSE
// file for more details or see below.
//
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// Version 2, December 2004
// This is free and unencumbered software released into the public domain.
//
// Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
// Anyone is free to copy, modify, publish, use, compile, sell, or
// distribute this software, either in source code form or as a compiled
// binary, for any purpose, commercial or non-commercial, and by any
// means.
//
// Everyone is permitted to copy and distribute verbatim or modified
// copies of this license document, and changing it is allowed as long
// as the name is changed.
// In jurisdictions that recognize copyright laws, the author or authors
// of this software dedicate any and all copyright interest in the
// software to the public domain. We make this dedication for the benefit
// of the public at large and to the detriment of our heirs and
// successors. We intend this dedication to be an overt act of
// relinquishment in perpetuity of all present and future rights to this
// software under copyright law.
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
//
// 0. You just DO WHAT THE FUCK YOU WANT TO.
// For more information, please refer to <http://unlicense.org/>
//
package aurora

View File

@@ -1,26 +1,36 @@
//
// Copyright (c) 2016-2019 The Aurora Authors. All rights reserved.
// Copyright (c) 2016-2020 The Aurora Authors. All rights reserved.
// This program is free software. It comes without any warranty,
// to the extent permitted by applicable law. You can redistribute
// it and/or modify it under the terms of the Do What The Fuck You
// Want To Public License, Version 2, as published by Sam Hocevar.
// See LICENSE file for more details or see below.
// it and/or modify it under the terms of the Unlicense. See LICENSE
// file for more details or see below.
//
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// Version 2, December 2004
// This is free and unencumbered software released into the public domain.
//
// Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
// Anyone is free to copy, modify, publish, use, compile, sell, or
// distribute this software, either in source code form or as a compiled
// binary, for any purpose, commercial or non-commercial, and by any
// means.
//
// Everyone is permitted to copy and distribute verbatim or modified
// copies of this license document, and changing it is allowed as long
// as the name is changed.
// In jurisdictions that recognize copyright laws, the author or authors
// of this software dedicate any and all copyright interest in the
// software to the public domain. We make this dedication for the benefit
// of the public at large and to the detriment of our heirs and
// successors. We intend this dedication to be an overt act of
// relinquishment in perpetuity of all present and future rights to this
// software under copyright law.
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
//
// 0. You just DO WHAT THE FUCK YOU WANT TO.
// For more information, please refer to <http://unlicense.org/>
//
package aurora

View File

@@ -1,26 +1,36 @@
//
// Copyright (c) 2016-2019 The Aurora Authors. All rights reserved.
// Copyright (c) 2016-2020 The Aurora Authors. All rights reserved.
// This program is free software. It comes without any warranty,
// to the extent permitted by applicable law. You can redistribute
// it and/or modify it under the terms of the Do What The Fuck You
// Want To Public License, Version 2, as published by Sam Hocevar.
// See LICENSE file for more details or see below.
// it and/or modify it under the terms of the Unlicense. See LICENSE
// file for more details or see below.
//
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// Version 2, December 2004
// This is free and unencumbered software released into the public domain.
//
// Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
// Anyone is free to copy, modify, publish, use, compile, sell, or
// distribute this software, either in source code form or as a compiled
// binary, for any purpose, commercial or non-commercial, and by any
// means.
//
// Everyone is permitted to copy and distribute verbatim or modified
// copies of this license document, and changing it is allowed as long
// as the name is changed.
// In jurisdictions that recognize copyright laws, the author or authors
// of this software dedicate any and all copyright interest in the
// software to the public domain. We make this dedication for the benefit
// of the public at large and to the detriment of our heirs and
// successors. We intend this dedication to be an overt act of
// relinquishment in perpetuity of all present and future rights to this
// software under copyright law.
//
// DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
// TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
// OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
//
// 0. You just DO WHAT THE FUCK YOU WANT TO.
// For more information, please refer to <http://unlicense.org/>
//
package aurora

14
vendor/github.com/rs/xid/README.md generated vendored
View File

@@ -2,9 +2,9 @@
[![godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/rs/xid) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/rs/xid/master/LICENSE) [![Build Status](https://travis-ci.org/rs/xid.svg?branch=master)](https://travis-ci.org/rs/xid) [![Coverage](http://gocover.io/_badge/github.com/rs/xid)](http://gocover.io/github.com/rs/xid)
Package xid is a globally unique id generator library, ready to be used safely directly in your server code.
Package xid is a globally unique id generator library, ready to safely be used directly in your server code.
Xid is using Mongo Object ID algorithm to generate globally unique ids with a different serialization (base64) to make it shorter when transported as a string:
Xid uses the Mongo Object ID algorithm to generate globally unique ids with a different serialization (base64) to make it shorter when transported as a string:
https://docs.mongodb.org/manual/reference/object-id/
- 4-byte value representing the seconds since the Unix epoch,
@@ -33,7 +33,7 @@ is required so it can be used directly in server's code.
|-------------|-------------|----------------|----------------
| [UUID] | 16 bytes | 36 chars | configuration free, not sortable
| [shortuuid] | 16 bytes | 22 chars | configuration free, not sortable
| [Snowflake] | 8 bytes | up to 20 chars | needs machin/DC configuration, needs central server, sortable
| [Snowflake] | 8 bytes | up to 20 chars | needs machine/DC configuration, needs central server, sortable
| [MongoID] | 12 bytes | 24 chars | configuration free, sortable
| xid | 12 bytes | 20 chars | configuration free, sortable
@@ -57,7 +57,7 @@ Best used with [zerolog](https://github.com/rs/zerolog)'s
Notes:
- Xid is dependent on the system time, a monotonic counter and so is not cryptographically secure. If unpredictability of IDs is important, you should not use Xids. It is worth noting that most of the other UUID like implementations are also not cryptographically secure. You shoud use libraries that rely on cryptographically secure sources (like /dev/urandom on unix, crypto/rand in golang), if you want a truly random ID generator.
- Xid is dependent on the system time, a monotonic counter and so is not cryptographically secure. If unpredictability of IDs is important, you should not use Xids. It is worth noting that most other UUID-like implementations are also not cryptographically secure. You should use libraries that rely on cryptographically secure sources (like /dev/urandom on unix, crypto/rand in golang), if you want a truly random ID generator.
References:
@@ -66,6 +66,10 @@ References:
- https://blog.twitter.com/2010/announcing-snowflake
- Python port by [Graham Abbott](https://github.com/graham): https://github.com/graham/python_xid
- Scala port by [Egor Kolotaev](https://github.com/kolotaev): https://github.com/kolotaev/ride
- Rust port by [Jérôme Renard](https://github.com/jeromer/): https://github.com/jeromer/libxid
- Ruby port by [Valar](https://github.com/valarpirai/): https://github.com/valarpirai/ruby_xid
- Java port by [0xShamil](https://github.com/0xShamil/): https://github.com/0xShamil/java-xid
- Dart port by [Peter Bwire](https://github.com/pitabwire): https://pub.dev/packages/xid
## Install
@@ -105,7 +109,7 @@ BenchmarkUUIDv4-2 1000000 1427 ns/op 64 B/op 2 allocs/op
BenchmarkUUIDv4-4 1000000 1452 ns/op 64 B/op 2 allocs/op
```
Note: UUIDv1 requires a global lock, hence the performence degrading as we add more CPUs.
Note: UUIDv1 requires a global lock, hence the performance degradation as we add more CPUs.
## Licenses

11
vendor/github.com/rs/xid/error.go generated vendored Normal file
View File

@@ -0,0 +1,11 @@
package xid
const (
// ErrInvalidID is returned when trying to unmarshal an invalid ID.
ErrInvalidID strErr = "xid: invalid ID"
)
// strErr allows declaring errors as constants.
type strErr string
func (err strErr) Error() string { return string(err) }

1
vendor/github.com/rs/xid/go.mod generated vendored
View File

@@ -1 +0,0 @@
module github.com/rs/xid

View File

@@ -5,6 +5,9 @@ package xid
import "io/ioutil"
func readPlatformMachineID() (string, error) {
b, err := ioutil.ReadFile("/sys/class/dmi/id/product_uuid")
b, err := ioutil.ReadFile("/etc/machine-id")
if err != nil || len(b) == 0 {
b, err = ioutil.ReadFile("/sys/class/dmi/id/product_uuid")
}
return string(b), err
}

107
vendor/github.com/rs/xid/id.go generated vendored
View File

@@ -47,7 +47,6 @@ import (
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
@@ -55,6 +54,7 @@ import (
"sort"
"sync/atomic"
"time"
"unsafe"
)
// Code inspired from mgo/bson ObjectId
@@ -72,9 +72,6 @@ const (
)
var (
// ErrInvalidID is returned when trying to unmarshal an invalid ID
ErrInvalidID = errors.New("xid: invalid ID")
// objectIDCounter is atomically incremented when generating a new ObjectId
// using NewObjectId() function. It's used as a counter part of an id.
// This id is initialized with a random value.
@@ -177,7 +174,13 @@ func FromString(id string) (ID, error) {
func (id ID) String() string {
text := make([]byte, encodedLen)
encode(text, id[:])
return string(text)
return *(*string)(unsafe.Pointer(&text))
}
// Encode encodes the id using base32 encoding, writing 20 bytes to dst and return it.
func (id ID) Encode(dst []byte) []byte {
encode(dst, id[:])
return dst
}
// MarshalText implements encoding/text TextMarshaler interface
@@ -192,32 +195,37 @@ func (id ID) MarshalJSON() ([]byte, error) {
if id.IsNil() {
return []byte("null"), nil
}
text, err := id.MarshalText()
return []byte(`"` + string(text) + `"`), err
text := make([]byte, encodedLen+2)
encode(text[1:encodedLen+1], id[:])
text[0], text[encodedLen+1] = '"', '"'
return text, nil
}
// encode by unrolling the stdlib base32 algorithm + removing all safe checks
func encode(dst, id []byte) {
dst[0] = encoding[id[0]>>3]
dst[1] = encoding[(id[1]>>6)&0x1F|(id[0]<<2)&0x1F]
dst[2] = encoding[(id[1]>>1)&0x1F]
dst[3] = encoding[(id[2]>>4)&0x1F|(id[1]<<4)&0x1F]
dst[4] = encoding[id[3]>>7|(id[2]<<1)&0x1F]
dst[5] = encoding[(id[3]>>2)&0x1F]
dst[6] = encoding[id[4]>>5|(id[3]<<3)&0x1F]
dst[7] = encoding[id[4]&0x1F]
dst[8] = encoding[id[5]>>3]
dst[9] = encoding[(id[6]>>6)&0x1F|(id[5]<<2)&0x1F]
dst[10] = encoding[(id[6]>>1)&0x1F]
dst[11] = encoding[(id[7]>>4)&0x1F|(id[6]<<4)&0x1F]
dst[12] = encoding[id[8]>>7|(id[7]<<1)&0x1F]
dst[13] = encoding[(id[8]>>2)&0x1F]
dst[14] = encoding[(id[9]>>5)|(id[8]<<3)&0x1F]
dst[15] = encoding[id[9]&0x1F]
dst[16] = encoding[id[10]>>3]
dst[17] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
dst[18] = encoding[(id[11]>>1)&0x1F]
_ = dst[19]
_ = id[11]
dst[19] = encoding[(id[11]<<4)&0x1F]
dst[18] = encoding[(id[11]>>1)&0x1F]
dst[17] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
dst[16] = encoding[id[10]>>3]
dst[15] = encoding[id[9]&0x1F]
dst[14] = encoding[(id[9]>>5)|(id[8]<<3)&0x1F]
dst[13] = encoding[(id[8]>>2)&0x1F]
dst[12] = encoding[id[8]>>7|(id[7]<<1)&0x1F]
dst[11] = encoding[(id[7]>>4)&0x1F|(id[6]<<4)&0x1F]
dst[10] = encoding[(id[6]>>1)&0x1F]
dst[9] = encoding[(id[6]>>6)&0x1F|(id[5]<<2)&0x1F]
dst[8] = encoding[id[5]>>3]
dst[7] = encoding[id[4]&0x1F]
dst[6] = encoding[id[4]>>5|(id[3]<<3)&0x1F]
dst[5] = encoding[(id[3]>>2)&0x1F]
dst[4] = encoding[id[3]>>7|(id[2]<<1)&0x1F]
dst[3] = encoding[(id[2]>>4)&0x1F|(id[1]<<4)&0x1F]
dst[2] = encoding[(id[1]>>1)&0x1F]
dst[1] = encoding[(id[1]>>6)&0x1F|(id[0]<<2)&0x1F]
dst[0] = encoding[id[0]>>3]
}
// UnmarshalText implements encoding/text TextUnmarshaler interface
@@ -230,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error {
return ErrInvalidID
}
}
decode(id, text)
if !decode(id, text) {
return ErrInvalidID
}
return nil
}
@@ -241,23 +251,40 @@ func (id *ID) UnmarshalJSON(b []byte) error {
*id = nilID
return nil
}
// Check the slice length to prevent panic on passing it to UnmarshalText()
if len(b) < 2 {
return ErrInvalidID
}
return id.UnmarshalText(b[1 : len(b)-1])
}
// decode by unrolling the stdlib base32 algorithm + removing all safe checks
func decode(id *ID, src []byte) {
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
id[3] = dec[src[4]]<<7 | dec[src[5]]<<2 | dec[src[6]]>>3
id[4] = dec[src[6]]<<5 | dec[src[7]]
id[5] = dec[src[8]]<<3 | dec[src[9]]>>2
id[6] = dec[src[9]]<<6 | dec[src[10]]<<1 | dec[src[11]]>>4
id[7] = dec[src[11]]<<4 | dec[src[12]]>>1
id[8] = dec[src[12]]<<7 | dec[src[13]]<<2 | dec[src[14]]>>3
id[9] = dec[src[14]]<<5 | dec[src[15]]
id[10] = dec[src[16]]<<3 | dec[src[17]]>>2
// decode by unrolling the stdlib base32 algorithm + customized safe check.
func decode(id *ID, src []byte) bool {
_ = src[19]
_ = id[11]
id[11] = dec[src[17]]<<6 | dec[src[18]]<<1 | dec[src[19]]>>4
id[10] = dec[src[16]]<<3 | dec[src[17]]>>2
id[9] = dec[src[14]]<<5 | dec[src[15]]
id[8] = dec[src[12]]<<7 | dec[src[13]]<<2 | dec[src[14]]>>3
id[7] = dec[src[11]]<<4 | dec[src[12]]>>1
id[6] = dec[src[9]]<<6 | dec[src[10]]<<1 | dec[src[11]]>>4
id[5] = dec[src[8]]<<3 | dec[src[9]]>>2
id[4] = dec[src[6]]<<5 | dec[src[7]]
id[3] = dec[src[4]]<<7 | dec[src[5]]<<2 | dec[src[6]]>>3
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2
// Validate that there are no discarer bits (padding) in src that would
// cause the string-encoded id not to equal src.
var check [4]byte
check[3] = encoding[(id[11]<<4)&0x1F]
check[2] = encoding[(id[11]>>1)&0x1F]
check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
check[0] = encoding[id[10]>>3]
return bytes.Equal([]byte(src[16:20]), check[:])
}
// Time returns the timestamp part of the id.

View File

@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2012-2018 Mat Ryer and Tyler Bunnell
Copyright (c) 2012-2020 Mat Ryer, Tyler Bunnell and contributors.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -0,0 +1,436 @@
package assert
import (
"fmt"
"reflect"
"time"
)
type CompareType int
const (
compareLess CompareType = iota - 1
compareEqual
compareGreater
)
var (
intType = reflect.TypeOf(int(1))
int8Type = reflect.TypeOf(int8(1))
int16Type = reflect.TypeOf(int16(1))
int32Type = reflect.TypeOf(int32(1))
int64Type = reflect.TypeOf(int64(1))
uintType = reflect.TypeOf(uint(1))
uint8Type = reflect.TypeOf(uint8(1))
uint16Type = reflect.TypeOf(uint16(1))
uint32Type = reflect.TypeOf(uint32(1))
uint64Type = reflect.TypeOf(uint64(1))
float32Type = reflect.TypeOf(float32(1))
float64Type = reflect.TypeOf(float64(1))
stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
obj1Value := reflect.ValueOf(obj1)
obj2Value := reflect.ValueOf(obj2)
// throughout this switch we try and avoid calling .Convert() if possible,
// as this has a pretty big performance impact
switch kind {
case reflect.Int:
{
intobj1, ok := obj1.(int)
if !ok {
intobj1 = obj1Value.Convert(intType).Interface().(int)
}
intobj2, ok := obj2.(int)
if !ok {
intobj2 = obj2Value.Convert(intType).Interface().(int)
}
if intobj1 > intobj2 {
return compareGreater, true
}
if intobj1 == intobj2 {
return compareEqual, true
}
if intobj1 < intobj2 {
return compareLess, true
}
}
case reflect.Int8:
{
int8obj1, ok := obj1.(int8)
if !ok {
int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
}
int8obj2, ok := obj2.(int8)
if !ok {
int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
}
if int8obj1 > int8obj2 {
return compareGreater, true
}
if int8obj1 == int8obj2 {
return compareEqual, true
}
if int8obj1 < int8obj2 {
return compareLess, true
}
}
case reflect.Int16:
{
int16obj1, ok := obj1.(int16)
if !ok {
int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
}
int16obj2, ok := obj2.(int16)
if !ok {
int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
}
if int16obj1 > int16obj2 {
return compareGreater, true
}
if int16obj1 == int16obj2 {
return compareEqual, true
}
if int16obj1 < int16obj2 {
return compareLess, true
}
}
case reflect.Int32:
{
int32obj1, ok := obj1.(int32)
if !ok {
int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
}
int32obj2, ok := obj2.(int32)
if !ok {
int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
}
if int32obj1 > int32obj2 {
return compareGreater, true
}
if int32obj1 == int32obj2 {
return compareEqual, true
}
if int32obj1 < int32obj2 {
return compareLess, true
}
}
case reflect.Int64:
{
int64obj1, ok := obj1.(int64)
if !ok {
int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
}
int64obj2, ok := obj2.(int64)
if !ok {
int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
}
if int64obj1 > int64obj2 {
return compareGreater, true
}
if int64obj1 == int64obj2 {
return compareEqual, true
}
if int64obj1 < int64obj2 {
return compareLess, true
}
}
case reflect.Uint:
{
uintobj1, ok := obj1.(uint)
if !ok {
uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
}
uintobj2, ok := obj2.(uint)
if !ok {
uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
}
if uintobj1 > uintobj2 {
return compareGreater, true
}
if uintobj1 == uintobj2 {
return compareEqual, true
}
if uintobj1 < uintobj2 {
return compareLess, true
}
}
case reflect.Uint8:
{
uint8obj1, ok := obj1.(uint8)
if !ok {
uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
}
uint8obj2, ok := obj2.(uint8)
if !ok {
uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
}
if uint8obj1 > uint8obj2 {
return compareGreater, true
}
if uint8obj1 == uint8obj2 {
return compareEqual, true
}
if uint8obj1 < uint8obj2 {
return compareLess, true
}
}
case reflect.Uint16:
{
uint16obj1, ok := obj1.(uint16)
if !ok {
uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
}
uint16obj2, ok := obj2.(uint16)
if !ok {
uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
}
if uint16obj1 > uint16obj2 {
return compareGreater, true
}
if uint16obj1 == uint16obj2 {
return compareEqual, true
}
if uint16obj1 < uint16obj2 {
return compareLess, true
}
}
case reflect.Uint32:
{
uint32obj1, ok := obj1.(uint32)
if !ok {
uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
}
uint32obj2, ok := obj2.(uint32)
if !ok {
uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
}
if uint32obj1 > uint32obj2 {
return compareGreater, true
}
if uint32obj1 == uint32obj2 {
return compareEqual, true
}
if uint32obj1 < uint32obj2 {
return compareLess, true
}
}
case reflect.Uint64:
{
uint64obj1, ok := obj1.(uint64)
if !ok {
uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
}
uint64obj2, ok := obj2.(uint64)
if !ok {
uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
}
if uint64obj1 > uint64obj2 {
return compareGreater, true
}
if uint64obj1 == uint64obj2 {
return compareEqual, true
}
if uint64obj1 < uint64obj2 {
return compareLess, true
}
}
case reflect.Float32:
{
float32obj1, ok := obj1.(float32)
if !ok {
float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
}
float32obj2, ok := obj2.(float32)
if !ok {
float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
}
if float32obj1 > float32obj2 {
return compareGreater, true
}
if float32obj1 == float32obj2 {
return compareEqual, true
}
if float32obj1 < float32obj2 {
return compareLess, true
}
}
case reflect.Float64:
{
float64obj1, ok := obj1.(float64)
if !ok {
float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
}
float64obj2, ok := obj2.(float64)
if !ok {
float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
}
if float64obj1 > float64obj2 {
return compareGreater, true
}
if float64obj1 == float64obj2 {
return compareEqual, true
}
if float64obj1 < float64obj2 {
return compareLess, true
}
}
case reflect.String:
{
stringobj1, ok := obj1.(string)
if !ok {
stringobj1 = obj1Value.Convert(stringType).Interface().(string)
}
stringobj2, ok := obj2.(string)
if !ok {
stringobj2 = obj2Value.Convert(stringType).Interface().(string)
}
if stringobj1 > stringobj2 {
return compareGreater, true
}
if stringobj1 == stringobj2 {
return compareEqual, true
}
if stringobj1 < stringobj2 {
return compareLess, true
}
}
// Check for known struct types we can check for compare results.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) {
break
}
// time.Time can compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}
timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
}
return compareEqual, false
}
// Greater asserts that the first element is greater than the second
//
// assert.Greater(t, 2, 1)
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// GreaterOrEqual asserts that the first element is greater than or equal to the second
//
// assert.GreaterOrEqual(t, 2, 1)
// assert.GreaterOrEqual(t, 2, 2)
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// Less asserts that the first element is less than the second
//
// assert.Less(t, 1, 2)
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// LessOrEqual asserts that the first element is less than or equal to the second
//
// assert.LessOrEqual(t, 1, 2)
// assert.LessOrEqual(t, 2, 2)
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}
// Positive asserts that the specified element is positive
//
// assert.Positive(t, 1)
// assert.Positive(t, 1.23)
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
}
// Negative asserts that the specified element is negative
//
// assert.Negative(t, -1)
// assert.Negative(t, -1.23)
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
}
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
e1Kind := reflect.ValueOf(e1).Kind()
e2Kind := reflect.ValueOf(e2).Kind()
if e1Kind != e2Kind {
return Fail(t, "Elements should be the same type", msgAndArgs...)
}
compareResult, isComparable := compare(e1, e2, e1Kind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
}
if !containsValue(allowedComparesResults, compareResult) {
return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
}
return true
}
func containsValue(values []CompareType, value CompareType) bool {
for _, v := range values {
if v == value {
return true
}
}
return false
}

View File

@@ -0,0 +1,16 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatability
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View File

@@ -0,0 +1,16 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View File

@@ -32,7 +32,8 @@ func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args
return Contains(t, s, contains, append([]interface{}{msg}, args...)...)
}
// DirExistsf checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
// DirExistsf checks whether a directory exists in the given path. It also fails
// if the path is a file rather a directory or there is an error checking whether it exists.
func DirExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -92,7 +93,7 @@ func EqualErrorf(t TestingT, theError error, errString string, msg string, args
// EqualValuesf asserts that two objects are equal or convertable to the same types
// and equal.
//
// assert.EqualValuesf(t, uint32(123, "error message %s", "formatted"), int32(123))
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -113,6 +114,36 @@ func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
return Error(t, err, append([]interface{}{msg}, args...)...)
}
// ErrorAsf asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value.
// This is a wrapper for errors.As.
func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
}
// ErrorIsf asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorIs(t, err, target, append([]interface{}{msg}, args...)...)
}
// Eventuallyf asserts that given condition will be met in waitFor time,
// periodically checking target function each tick.
//
@@ -126,7 +157,7 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick
// Exactlyf asserts that two objects are equal in value and type.
//
// assert.Exactlyf(t, int32(123, "error message %s", "formatted"), int64(123))
// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted")
func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -160,7 +191,8 @@ func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool {
return False(t, value, append([]interface{}{msg}, args...)...)
}
// FileExistsf checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
// FileExistsf checks whether a file exists in the given path. It also fails if
// the path points to a directory or there is an error when trying to check the file.
func FileExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -171,7 +203,7 @@ func FileExistsf(t TestingT, path string, msg string, args ...interface{}) bool
// Greaterf asserts that the first element is greater than the second
//
// assert.Greaterf(t, 2, 1, "error message %s", "formatted")
// assert.Greaterf(t, float64(2, "error message %s", "formatted"), float64(1))
// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted")
// assert.Greaterf(t, "b", "a", "error message %s", "formatted")
func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@@ -223,7 +255,7 @@ func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, u
//
// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
// Returns whether the assertion was successful (true) or not (false).
func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -235,7 +267,7 @@ func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string,
//
// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
// Returns whether the assertion was successful (true) or not (false).
func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -243,6 +275,18 @@ func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url stri
return HTTPRedirect(t, handler, method, url, values, append([]interface{}{msg}, args...)...)
}
// HTTPStatusCodef asserts that a specified handler returns a specified status code.
//
// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return HTTPStatusCode(t, handler, method, url, values, statuscode, append([]interface{}{msg}, args...)...)
}
// HTTPSuccessf asserts that a specified handler returns a success status code.
//
// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted")
@@ -257,7 +301,7 @@ func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url strin
// Implementsf asserts that an object is implemented by the specified interface.
//
// assert.Implementsf(t, (*MyInterface, "error message %s", "formatted")(nil), new(MyObject))
// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -267,7 +311,7 @@ func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, ms
// InDeltaf asserts that the two numerals are within delta of each other.
//
// assert.InDeltaf(t, math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01)
// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted")
func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -307,6 +351,54 @@ func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsil
return InEpsilonSlice(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...)
}
// IsDecreasingf asserts that the collection is decreasing
//
// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted")
// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted")
// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted")
func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return IsDecreasing(t, object, append([]interface{}{msg}, args...)...)
}
// IsIncreasingf asserts that the collection is increasing
//
// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted")
// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted")
// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted")
func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return IsIncreasing(t, object, append([]interface{}{msg}, args...)...)
}
// IsNonDecreasingf asserts that the collection is not decreasing
//
// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted")
// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted")
// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted")
func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return IsNonDecreasing(t, object, append([]interface{}{msg}, args...)...)
}
// IsNonIncreasingf asserts that the collection is not increasing
//
// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted")
// assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted")
// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted")
func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return IsNonIncreasing(t, object, append([]interface{}{msg}, args...)...)
}
// IsTypef asserts that the specified objects are of the same type.
func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@@ -325,14 +417,6 @@ func JSONEqf(t TestingT, expected string, actual string, msg string, args ...int
return JSONEq(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// YAMLEqf asserts that two YAML strings are equivalent.
func YAMLEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return YAMLEq(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// Lenf asserts that the specified object has specific length.
// Lenf also fails if the object has a type that len() not accept.
//
@@ -347,7 +431,7 @@ func Lenf(t TestingT, object interface{}, length int, msg string, args ...interf
// Lessf asserts that the first element is less than the second
//
// assert.Lessf(t, 1, 2, "error message %s", "formatted")
// assert.Lessf(t, float64(1, "error message %s", "formatted"), float64(2))
// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted")
// assert.Lessf(t, "a", "b", "error message %s", "formatted")
func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@@ -369,6 +453,28 @@ func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args .
return LessOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
}
// Negativef asserts that the specified element is negative
//
// assert.Negativef(t, -1, "error message %s", "formatted")
// assert.Negativef(t, -1.23, "error message %s", "formatted")
func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Negative(t, e, append([]interface{}{msg}, args...)...)
}
// Neverf asserts that the given condition doesn't satisfy in waitFor time,
// periodically checking the target function each tick.
//
// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Never(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
}
// Nilf asserts that the specified object is nil.
//
// assert.Nilf(t, err, "error message %s", "formatted")
@@ -379,6 +485,15 @@ func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool
return Nil(t, object, append([]interface{}{msg}, args...)...)
}
// NoDirExistsf checks whether a directory does not exist in the given path.
// It fails if the path points to an existing _directory_ only.
func NoDirExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NoDirExists(t, path, append([]interface{}{msg}, args...)...)
}
// NoErrorf asserts that a function returned no error (i.e. `nil`).
//
// actualObj, err := SomeFunction()
@@ -392,6 +507,15 @@ func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
return NoError(t, err, append([]interface{}{msg}, args...)...)
}
// NoFileExistsf checks whether a file does not exist in a given path. It fails
// if the path points to an existing _file_ only.
func NoFileExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NoFileExists(t, path, append([]interface{}{msg}, args...)...)
}
// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
@@ -431,6 +555,25 @@ func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string,
return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// NotEqualValuesf asserts that two objects are not equal even when converted to the same type
//
// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted")
func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// NotErrorIsf asserts that at none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...)
}
// NotNilf asserts that the specified object is not nil.
//
// assert.NotNilf(t, err, "error message %s", "formatted")
@@ -453,7 +596,7 @@ func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bo
// NotRegexpf asserts that a specified regexp does not match a string.
//
// assert.NotRegexpf(t, regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting")
// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted")
// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted")
func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@@ -462,6 +605,19 @@ func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ..
return NotRegexp(t, rx, str, append([]interface{}{msg}, args...)...)
}
// NotSamef asserts that two pointers do not reference the same object.
//
// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted")
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// NotSubsetf asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
//
@@ -491,6 +647,18 @@ func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool
return Panics(t, f, append([]interface{}{msg}, args...)...)
}
// PanicsWithErrorf asserts that the code inside the specified PanicTestFunc
// panics, and that the recovered panic value is an error that satisfies the
// EqualError comparison.
//
// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
func PanicsWithErrorf(t TestingT, errString string, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return PanicsWithError(t, errString, f, append([]interface{}{msg}, args...)...)
}
// PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
//
@@ -502,9 +670,20 @@ func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg str
return PanicsWithValue(t, expected, f, append([]interface{}{msg}, args...)...)
}
// Positivef asserts that the specified element is positive
//
// assert.Positivef(t, 1, "error message %s", "formatted")
// assert.Positivef(t, 1.23, "error message %s", "formatted")
func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Positive(t, e, append([]interface{}{msg}, args...)...)
}
// Regexpf asserts that a specified regexp matches a string.
//
// assert.Regexpf(t, regexp.MustCompile("start", "error message %s", "formatted"), "it's starting")
// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted")
// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted")
func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
@@ -557,6 +736,14 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
}
// YAMLEqf asserts that two YAML strings are equivalent.
func YAMLEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return YAMLEq(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// Zerof asserts that i is the zero value for its type.
func Zerof(t TestingT, i interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {

View File

@@ -53,7 +53,8 @@ func (a *Assertions) Containsf(s interface{}, contains interface{}, msg string,
return Containsf(a.t, s, contains, msg, args...)
}
// DirExists checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
// DirExists checks whether a directory exists in the given path. It also fails
// if the path is a file rather a directory or there is an error checking whether it exists.
func (a *Assertions) DirExists(path string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -61,7 +62,8 @@ func (a *Assertions) DirExists(path string, msgAndArgs ...interface{}) bool {
return DirExists(a.t, path, msgAndArgs...)
}
// DirExistsf checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
// DirExistsf checks whether a directory exists in the given path. It also fails
// if the path is a file rather a directory or there is an error checking whether it exists.
func (a *Assertions) DirExistsf(path string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -167,7 +169,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
// EqualValuesf asserts that two objects are equal or convertable to the same types
// and equal.
//
// a.EqualValuesf(uint32(123, "error message %s", "formatted"), int32(123))
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -202,6 +204,66 @@ func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool {
return Error(a.t, err, msgAndArgs...)
}
// ErrorAs asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value.
// This is a wrapper for errors.As.
func (a *Assertions) ErrorAs(err error, target interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorAs(a.t, err, target, msgAndArgs...)
}
// ErrorAsf asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value.
// This is a wrapper for errors.As.
func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorAsf(a.t, err, target, msg, args...)
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContains(err, expectedErrorSubString)
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContains(a.t, theError, contains, msgAndArgs...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorContainsf(a.t, theError, contains, msg, args...)
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorIs(a.t, err, target, msgAndArgs...)
}
// ErrorIsf asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ErrorIsf(a.t, err, target, msg, args...)
}
// Errorf asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
@@ -249,7 +311,7 @@ func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArg
// Exactlyf asserts that two objects are equal in value and type.
//
// a.Exactlyf(int32(123, "error message %s", "formatted"), int64(123))
// a.Exactlyf(int32(123), int64(123), "error message %s", "formatted")
func (a *Assertions) Exactlyf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -309,7 +371,8 @@ func (a *Assertions) Falsef(value bool, msg string, args ...interface{}) bool {
return Falsef(a.t, value, msg, args...)
}
// FileExists checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
// FileExists checks whether a file exists in the given path. It also fails if
// the path points to a directory or there is an error when trying to check the file.
func (a *Assertions) FileExists(path string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -317,7 +380,8 @@ func (a *Assertions) FileExists(path string, msgAndArgs ...interface{}) bool {
return FileExists(a.t, path, msgAndArgs...)
}
// FileExistsf checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
// FileExistsf checks whether a file exists in the given path. It also fails if
// the path points to a directory or there is an error when trying to check the file.
func (a *Assertions) FileExistsf(path string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -366,7 +430,7 @@ func (a *Assertions) GreaterOrEqualf(e1 interface{}, e2 interface{}, msg string,
// Greaterf asserts that the first element is greater than the second
//
// a.Greaterf(2, 1, "error message %s", "formatted")
// a.Greaterf(float64(2, "error message %s", "formatted"), float64(1))
// a.Greaterf(float64(2), float64(1), "error message %s", "formatted")
// a.Greaterf("b", "a", "error message %s", "formatted")
func (a *Assertions) Greaterf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
@@ -443,7 +507,7 @@ func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url stri
//
// a.HTTPErrorf(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -467,7 +531,7 @@ func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url s
//
// a.HTTPRedirectf(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -475,6 +539,30 @@ func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url
return HTTPRedirectf(a.t, handler, method, url, values, msg, args...)
}
// HTTPStatusCode asserts that a specified handler returns a specified status code.
//
// a.HTTPStatusCode(myHandler, "GET", "/notImplemented", nil, 501)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPStatusCode(handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPStatusCode(a.t, handler, method, url, values, statuscode, msgAndArgs...)
}
// HTTPStatusCodef asserts that a specified handler returns a specified status code.
//
// a.HTTPStatusCodef(myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPStatusCodef(handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPStatusCodef(a.t, handler, method, url, values, statuscode, msg, args...)
}
// HTTPSuccess asserts that a specified handler returns a success status code.
//
// a.HTTPSuccess(myHandler, "POST", "http://www.google.com", nil)
@@ -511,7 +599,7 @@ func (a *Assertions) Implements(interfaceObject interface{}, object interface{},
// Implementsf asserts that an object is implemented by the specified interface.
//
// a.Implementsf((*MyInterface, "error message %s", "formatted")(nil), new(MyObject))
// a.Implementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -521,7 +609,7 @@ func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}
// InDelta asserts that the two numerals are within delta of each other.
//
// a.InDelta(math.Pi, (22 / 7.0), 0.01)
// a.InDelta(math.Pi, 22/7.0, 0.01)
func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -563,7 +651,7 @@ func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, del
// InDeltaf asserts that the two numerals are within delta of each other.
//
// a.InDeltaf(math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01)
// a.InDeltaf(math.Pi, 22/7.0, 0.01, "error message %s", "formatted")
func (a *Assertions) InDeltaf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
@@ -603,6 +691,102 @@ func (a *Assertions) InEpsilonf(expected interface{}, actual interface{}, epsilo
return InEpsilonf(a.t, expected, actual, epsilon, msg, args...)
}
// IsDecreasing asserts that the collection is decreasing
//
// a.IsDecreasing([]int{2, 1, 0})
// a.IsDecreasing([]float{2, 1})
// a.IsDecreasing([]string{"b", "a"})
func (a *Assertions) IsDecreasing(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsDecreasing(a.t, object, msgAndArgs...)
}
// IsDecreasingf asserts that the collection is decreasing
//
// a.IsDecreasingf([]int{2, 1, 0}, "error message %s", "formatted")
// a.IsDecreasingf([]float{2, 1}, "error message %s", "formatted")
// a.IsDecreasingf([]string{"b", "a"}, "error message %s", "formatted")
func (a *Assertions) IsDecreasingf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsDecreasingf(a.t, object, msg, args...)
}
// IsIncreasing asserts that the collection is increasing
//
// a.IsIncreasing([]int{1, 2, 3})
// a.IsIncreasing([]float{1, 2})
// a.IsIncreasing([]string{"a", "b"})
func (a *Assertions) IsIncreasing(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsIncreasing(a.t, object, msgAndArgs...)
}
// IsIncreasingf asserts that the collection is increasing
//
// a.IsIncreasingf([]int{1, 2, 3}, "error message %s", "formatted")
// a.IsIncreasingf([]float{1, 2}, "error message %s", "formatted")
// a.IsIncreasingf([]string{"a", "b"}, "error message %s", "formatted")
func (a *Assertions) IsIncreasingf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsIncreasingf(a.t, object, msg, args...)
}
// IsNonDecreasing asserts that the collection is not decreasing
//
// a.IsNonDecreasing([]int{1, 1, 2})
// a.IsNonDecreasing([]float{1, 2})
// a.IsNonDecreasing([]string{"a", "b"})
func (a *Assertions) IsNonDecreasing(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsNonDecreasing(a.t, object, msgAndArgs...)
}
// IsNonDecreasingf asserts that the collection is not decreasing
//
// a.IsNonDecreasingf([]int{1, 1, 2}, "error message %s", "formatted")
// a.IsNonDecreasingf([]float{1, 2}, "error message %s", "formatted")
// a.IsNonDecreasingf([]string{"a", "b"}, "error message %s", "formatted")
func (a *Assertions) IsNonDecreasingf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsNonDecreasingf(a.t, object, msg, args...)
}
// IsNonIncreasing asserts that the collection is not increasing
//
// a.IsNonIncreasing([]int{2, 1, 1})
// a.IsNonIncreasing([]float{2, 1})
// a.IsNonIncreasing([]string{"b", "a"})
func (a *Assertions) IsNonIncreasing(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsNonIncreasing(a.t, object, msgAndArgs...)
}
// IsNonIncreasingf asserts that the collection is not increasing
//
// a.IsNonIncreasingf([]int{2, 1, 1}, "error message %s", "formatted")
// a.IsNonIncreasingf([]float{2, 1}, "error message %s", "formatted")
// a.IsNonIncreasingf([]string{"b", "a"}, "error message %s", "formatted")
func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsNonIncreasingf(a.t, object, msg, args...)
}
// IsType asserts that the specified objects are of the same type.
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
@@ -639,22 +823,6 @@ func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ..
return JSONEqf(a.t, expected, actual, msg, args...)
}
// YAMLEq asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return YAMLEq(a.t, expected, actual, msgAndArgs...)
}
// YAMLEqf asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEqf(expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return YAMLEqf(a.t, expected, actual, msg, args...)
}
// Len asserts that the specified object has specific length.
// Len also fails if the object has a type that len() not accept.
//
@@ -718,7 +886,7 @@ func (a *Assertions) LessOrEqualf(e1 interface{}, e2 interface{}, msg string, ar
// Lessf asserts that the first element is less than the second
//
// a.Lessf(1, 2, "error message %s", "formatted")
// a.Lessf(float64(1, "error message %s", "formatted"), float64(2))
// a.Lessf(float64(1), float64(2), "error message %s", "formatted")
// a.Lessf("a", "b", "error message %s", "formatted")
func (a *Assertions) Lessf(e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
@@ -727,6 +895,50 @@ func (a *Assertions) Lessf(e1 interface{}, e2 interface{}, msg string, args ...i
return Lessf(a.t, e1, e2, msg, args...)
}
// Negative asserts that the specified element is negative
//
// a.Negative(-1)
// a.Negative(-1.23)
func (a *Assertions) Negative(e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Negative(a.t, e, msgAndArgs...)
}
// Negativef asserts that the specified element is negative
//
// a.Negativef(-1, "error message %s", "formatted")
// a.Negativef(-1.23, "error message %s", "formatted")
func (a *Assertions) Negativef(e interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Negativef(a.t, e, msg, args...)
}
// Never asserts that the given condition doesn't satisfy in waitFor time,
// periodically checking the target function each tick.
//
// a.Never(func() bool { return false; }, time.Second, 10*time.Millisecond)
func (a *Assertions) Never(condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Never(a.t, condition, waitFor, tick, msgAndArgs...)
}
// Neverf asserts that the given condition doesn't satisfy in waitFor time,
// periodically checking the target function each tick.
//
// a.Neverf(func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
func (a *Assertions) Neverf(condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Neverf(a.t, condition, waitFor, tick, msg, args...)
}
// Nil asserts that the specified object is nil.
//
// a.Nil(err)
@@ -747,6 +959,24 @@ func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) b
return Nilf(a.t, object, msg, args...)
}
// NoDirExists checks whether a directory does not exist in the given path.
// It fails if the path points to an existing _directory_ only.
func (a *Assertions) NoDirExists(path string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NoDirExists(a.t, path, msgAndArgs...)
}
// NoDirExistsf checks whether a directory does not exist in the given path.
// It fails if the path points to an existing _directory_ only.
func (a *Assertions) NoDirExistsf(path string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NoDirExistsf(a.t, path, msg, args...)
}
// NoError asserts that a function returned no error (i.e. `nil`).
//
// actualObj, err := SomeFunction()
@@ -773,6 +1003,24 @@ func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) bool {
return NoErrorf(a.t, err, msg, args...)
}
// NoFileExists checks whether a file does not exist in a given path. It fails
// if the path points to an existing _file_ only.
func (a *Assertions) NoFileExists(path string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NoFileExists(a.t, path, msgAndArgs...)
}
// NoFileExistsf checks whether a file does not exist in a given path. It fails
// if the path points to an existing _file_ only.
func (a *Assertions) NoFileExistsf(path string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NoFileExistsf(a.t, path, msg, args...)
}
// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
@@ -838,6 +1086,26 @@ func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndAr
return NotEqual(a.t, expected, actual, msgAndArgs...)
}
// NotEqualValues asserts that two objects are not equal even when converted to the same type
//
// a.NotEqualValues(obj1, obj2)
func (a *Assertions) NotEqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotEqualValues(a.t, expected, actual, msgAndArgs...)
}
// NotEqualValuesf asserts that two objects are not equal even when converted to the same type
//
// a.NotEqualValuesf(obj1, obj2, "error message %s", "formatted")
func (a *Assertions) NotEqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotEqualValuesf(a.t, expected, actual, msg, args...)
}
// NotEqualf asserts that the specified values are NOT equal.
//
// a.NotEqualf(obj1, obj2, "error message %s", "formatted")
@@ -851,6 +1119,24 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str
return NotEqualf(a.t, expected, actual, msg, args...)
}
// NotErrorIs asserts that at none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotErrorIs(a.t, err, target, msgAndArgs...)
}
// NotErrorIsf asserts that at none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotErrorIsf(a.t, err, target, msg, args...)
}
// NotNil asserts that the specified object is not nil.
//
// a.NotNil(err)
@@ -904,7 +1190,7 @@ func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...in
// NotRegexpf asserts that a specified regexp does not match a string.
//
// a.NotRegexpf(regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting")
// a.NotRegexpf(regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted")
// a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted")
func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
@@ -913,6 +1199,32 @@ func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, arg
return NotRegexpf(a.t, rx, str, msg, args...)
}
// NotSame asserts that two pointers do not reference the same object.
//
// a.NotSame(ptr1, ptr2)
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
func (a *Assertions) NotSame(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotSame(a.t, expected, actual, msgAndArgs...)
}
// NotSamef asserts that two pointers do not reference the same object.
//
// a.NotSamef(ptr1, ptr2, "error message %s", "formatted")
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotSamef(a.t, expected, actual, msg, args...)
}
// NotSubset asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
//
@@ -961,6 +1273,30 @@ func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool {
return Panics(a.t, f, msgAndArgs...)
}
// PanicsWithError asserts that the code inside the specified PanicTestFunc
// panics, and that the recovered panic value is an error that satisfies the
// EqualError comparison.
//
// a.PanicsWithError("crazy error", func(){ GoCrazy() })
func (a *Assertions) PanicsWithError(errString string, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return PanicsWithError(a.t, errString, f, msgAndArgs...)
}
// PanicsWithErrorf asserts that the code inside the specified PanicTestFunc
// panics, and that the recovered panic value is an error that satisfies the
// EqualError comparison.
//
// a.PanicsWithErrorf("crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
func (a *Assertions) PanicsWithErrorf(errString string, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return PanicsWithErrorf(a.t, errString, f, msg, args...)
}
// PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
//
@@ -993,6 +1329,28 @@ func (a *Assertions) Panicsf(f PanicTestFunc, msg string, args ...interface{}) b
return Panicsf(a.t, f, msg, args...)
}
// Positive asserts that the specified element is positive
//
// a.Positive(1)
// a.Positive(1.23)
func (a *Assertions) Positive(e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Positive(a.t, e, msgAndArgs...)
}
// Positivef asserts that the specified element is positive
//
// a.Positivef(1, "error message %s", "formatted")
// a.Positivef(1.23, "error message %s", "formatted")
func (a *Assertions) Positivef(e interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Positivef(a.t, e, msg, args...)
}
// Regexp asserts that a specified regexp matches a string.
//
// a.Regexp(regexp.MustCompile("start"), "it's starting")
@@ -1006,7 +1364,7 @@ func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...inter
// Regexpf asserts that a specified regexp matches a string.
//
// a.Regexpf(regexp.MustCompile("start", "error message %s", "formatted"), "it's starting")
// a.Regexpf(regexp.MustCompile("start"), "it's starting", "error message %s", "formatted")
// a.Regexpf("start...$", "it's not starting", "error message %s", "formatted")
func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
@@ -1103,6 +1461,22 @@ func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta
return WithinDurationf(a.t, expected, actual, delta, msg, args...)
}
// YAMLEq asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEq(expected string, actual string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return YAMLEq(a.t, expected, actual, msgAndArgs...)
}
// YAMLEqf asserts that two YAML strings are equivalent.
func (a *Assertions) YAMLEqf(expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return YAMLEqf(a.t, expected, actual, msg, args...)
}
// Zero asserts that i is the zero value for its type.
func (a *Assertions) Zero(i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {

View File

@@ -5,305 +5,77 @@ import (
"reflect"
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (int, bool) {
switch kind {
case reflect.Int:
{
intobj1 := obj1.(int)
intobj2 := obj2.(int)
if intobj1 > intobj2 {
return -1, true
}
if intobj1 == intobj2 {
return 0, true
}
if intobj1 < intobj2 {
return 1, true
}
}
case reflect.Int8:
{
int8obj1 := obj1.(int8)
int8obj2 := obj2.(int8)
if int8obj1 > int8obj2 {
return -1, true
}
if int8obj1 == int8obj2 {
return 0, true
}
if int8obj1 < int8obj2 {
return 1, true
}
}
case reflect.Int16:
{
int16obj1 := obj1.(int16)
int16obj2 := obj2.(int16)
if int16obj1 > int16obj2 {
return -1, true
}
if int16obj1 == int16obj2 {
return 0, true
}
if int16obj1 < int16obj2 {
return 1, true
}
}
case reflect.Int32:
{
int32obj1 := obj1.(int32)
int32obj2 := obj2.(int32)
if int32obj1 > int32obj2 {
return -1, true
}
if int32obj1 == int32obj2 {
return 0, true
}
if int32obj1 < int32obj2 {
return 1, true
}
}
case reflect.Int64:
{
int64obj1 := obj1.(int64)
int64obj2 := obj2.(int64)
if int64obj1 > int64obj2 {
return -1, true
}
if int64obj1 == int64obj2 {
return 0, true
}
if int64obj1 < int64obj2 {
return 1, true
}
}
case reflect.Uint:
{
uintobj1 := obj1.(uint)
uintobj2 := obj2.(uint)
if uintobj1 > uintobj2 {
return -1, true
}
if uintobj1 == uintobj2 {
return 0, true
}
if uintobj1 < uintobj2 {
return 1, true
}
}
case reflect.Uint8:
{
uint8obj1 := obj1.(uint8)
uint8obj2 := obj2.(uint8)
if uint8obj1 > uint8obj2 {
return -1, true
}
if uint8obj1 == uint8obj2 {
return 0, true
}
if uint8obj1 < uint8obj2 {
return 1, true
}
}
case reflect.Uint16:
{
uint16obj1 := obj1.(uint16)
uint16obj2 := obj2.(uint16)
if uint16obj1 > uint16obj2 {
return -1, true
}
if uint16obj1 == uint16obj2 {
return 0, true
}
if uint16obj1 < uint16obj2 {
return 1, true
}
}
case reflect.Uint32:
{
uint32obj1 := obj1.(uint32)
uint32obj2 := obj2.(uint32)
if uint32obj1 > uint32obj2 {
return -1, true
}
if uint32obj1 == uint32obj2 {
return 0, true
}
if uint32obj1 < uint32obj2 {
return 1, true
}
}
case reflect.Uint64:
{
uint64obj1 := obj1.(uint64)
uint64obj2 := obj2.(uint64)
if uint64obj1 > uint64obj2 {
return -1, true
}
if uint64obj1 == uint64obj2 {
return 0, true
}
if uint64obj1 < uint64obj2 {
return 1, true
}
}
case reflect.Float32:
{
float32obj1 := obj1.(float32)
float32obj2 := obj2.(float32)
if float32obj1 > float32obj2 {
return -1, true
}
if float32obj1 == float32obj2 {
return 0, true
}
if float32obj1 < float32obj2 {
return 1, true
}
}
case reflect.Float64:
{
float64obj1 := obj1.(float64)
float64obj2 := obj2.(float64)
if float64obj1 > float64obj2 {
return -1, true
}
if float64obj1 == float64obj2 {
return 0, true
}
if float64obj1 < float64obj2 {
return 1, true
}
}
case reflect.String:
{
stringobj1 := obj1.(string)
stringobj2 := obj2.(string)
if stringobj1 > stringobj2 {
return -1, true
}
if stringobj1 == stringobj2 {
return 0, true
}
if stringobj1 < stringobj2 {
return 1, true
}
}
// isOrdered checks that collection contains orderable elements.
func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
objKind := reflect.TypeOf(object).Kind()
if objKind != reflect.Slice && objKind != reflect.Array {
return false
}
return 0, false
}
objValue := reflect.ValueOf(object)
objLen := objValue.Len()
// Greater asserts that the first element is greater than the second
//
// assert.Greater(t, 2, 1)
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
if objLen <= 1 {
return true
}
e1Kind := reflect.ValueOf(e1).Kind()
e2Kind := reflect.ValueOf(e2).Kind()
if e1Kind != e2Kind {
return Fail(t, "Elements should be the same type", msgAndArgs...)
}
value := objValue.Index(0)
valueInterface := value.Interface()
firstValueKind := value.Kind()
res, isComparable := compare(e1, e2, e1Kind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
}
for i := 1; i < objLen; i++ {
prevValue := value
prevValueInterface := valueInterface
if res != -1 {
return Fail(t, fmt.Sprintf("\"%v\" is not greater than \"%v\"", e1, e2), msgAndArgs...)
value = objValue.Index(i)
valueInterface = value.Interface()
compareResult, isComparable := compare(prevValueInterface, valueInterface, firstValueKind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\" and \"%s\"", reflect.TypeOf(value), reflect.TypeOf(prevValue)), msgAndArgs...)
}
if !containsValue(allowedComparesResults, compareResult) {
return Fail(t, fmt.Sprintf(failMessage, prevValue, value), msgAndArgs...)
}
}
return true
}
// GreaterOrEqual asserts that the first element is greater than or equal to the second
// IsIncreasing asserts that the collection is increasing
//
// assert.GreaterOrEqual(t, 2, 1)
// assert.GreaterOrEqual(t, 2, 2)
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
e1Kind := reflect.ValueOf(e1).Kind()
e2Kind := reflect.ValueOf(e2).Kind()
if e1Kind != e2Kind {
return Fail(t, "Elements should be the same type", msgAndArgs...)
}
res, isComparable := compare(e1, e2, e1Kind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
}
if res != -1 && res != 0 {
return Fail(t, fmt.Sprintf("\"%v\" is not greater than or equal to \"%v\"", e1, e2), msgAndArgs...)
}
return true
// assert.IsIncreasing(t, []int{1, 2, 3})
// assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// Less asserts that the first element is less than the second
// IsNonIncreasing asserts that the collection is not increasing
//
// assert.Less(t, 1, 2)
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
e1Kind := reflect.ValueOf(e1).Kind()
e2Kind := reflect.ValueOf(e2).Kind()
if e1Kind != e2Kind {
return Fail(t, "Elements should be the same type", msgAndArgs...)
}
res, isComparable := compare(e1, e2, e1Kind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
}
if res != 1 {
return Fail(t, fmt.Sprintf("\"%v\" is not less than \"%v\"", e1, e2), msgAndArgs...)
}
return true
// assert.IsNonIncreasing(t, []int{2, 1, 1})
// assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// LessOrEqual asserts that the first element is less than or equal to the second
// IsDecreasing asserts that the collection is decreasing
//
// assert.LessOrEqual(t, 1, 2)
// assert.LessOrEqual(t, 2, 2)
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
e1Kind := reflect.ValueOf(e1).Kind()
e2Kind := reflect.ValueOf(e2).Kind()
if e1Kind != e2Kind {
return Fail(t, "Elements should be the same type", msgAndArgs...)
}
res, isComparable := compare(e1, e2, e1Kind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
}
if res != 1 && res != 0 {
return Fail(t, fmt.Sprintf("\"%v\" is not less than or equal to \"%v\"", e1, e2), msgAndArgs...)
}
return true
// assert.IsDecreasing(t, []int{2, 1, 0})
// assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// IsNonDecreasing asserts that the collection is not decreasing
//
// assert.IsNonDecreasing(t, []int{1, 1, 2})
// assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}

View File

@@ -11,6 +11,7 @@ import (
"reflect"
"regexp"
"runtime"
"runtime/debug"
"strings"
"time"
"unicode"
@@ -18,10 +19,10 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
yaml "gopkg.in/yaml.v2"
yaml "gopkg.in/yaml.v3"
)
//go:generate go run ../_codegen/main.go -output-package=assert -template=assertion_format.go.tmpl
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl"
// TestingT is an interface wrapper around *testing.T
type TestingT interface {
@@ -44,7 +45,7 @@ type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool
// for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool
// Comparison a custom function that returns true on success and false on failure
// Comparison is a custom function that returns true on success and false on failure
type Comparison func() (success bool)
/*
@@ -103,11 +104,11 @@ the problem actually occurred in calling code.*/
// failed.
func CallerInfo() []string {
pc := uintptr(0)
file := ""
line := 0
ok := false
name := ""
var pc uintptr
var ok bool
var file string
var line int
var name string
callers := []string{}
for i := 0; ; i++ {
@@ -171,8 +172,8 @@ func isTest(name, prefix string) bool {
if len(name) == len(prefix) { // "Test" is ok
return true
}
rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
return !unicode.IsLower(rune)
r, _ := utf8.DecodeRuneInString(name[len(prefix):])
return !unicode.IsLower(r)
}
func messageFromMsgAndArgs(msgAndArgs ...interface{}) string {
@@ -351,6 +352,19 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{})
}
// validateEqualArgs checks whether provided arguments can be safely used in the
// Equal/NotEqual functions.
func validateEqualArgs(expected, actual interface{}) error {
if expected == nil && actual == nil {
return nil
}
if isFunction(expected) || isFunction(actual) {
return errors.New("cannot take func type as argument")
}
return nil
}
// Same asserts that two pointers reference the same object.
//
// assert.Same(t, ptr1, ptr2)
@@ -362,18 +376,7 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b
h.Helper()
}
expectedPtr, actualPtr := reflect.ValueOf(expected), reflect.ValueOf(actual)
if expectedPtr.Kind() != reflect.Ptr || actualPtr.Kind() != reflect.Ptr {
return Fail(t, "Invalid operation: both arguments must be pointers", msgAndArgs...)
}
expectedType, actualType := reflect.TypeOf(expected), reflect.TypeOf(actual)
if expectedType != actualType {
return Fail(t, fmt.Sprintf("Pointer expected to be of type %v, but was %v",
expectedType, actualType), msgAndArgs...)
}
if expected != actual {
if !samePointers(expected, actual) {
return Fail(t, fmt.Sprintf("Not same: \n"+
"expected: %p %#v\n"+
"actual : %p %#v", expected, expected, actual, actual), msgAndArgs...)
@@ -382,6 +385,42 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b
return true
}
// NotSame asserts that two pointers do not reference the same object.
//
// assert.NotSame(t, ptr1, ptr2)
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if samePointers(expected, actual) {
return Fail(t, fmt.Sprintf(
"Expected and actual point to the same object: %p %#v",
expected, expected), msgAndArgs...)
}
return true
}
// samePointers compares two generic interface objects and returns whether
// they point to the same object
func samePointers(first, second interface{}) bool {
firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second)
if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr {
return false
}
firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second)
if firstType != secondType {
return false
}
// compare pointer addresses
return first == second
}
// formatUnequalValues takes two values of arbitrary types and returns string
// representations appropriate to be presented to the user.
//
@@ -390,12 +429,27 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b
// to a type conversion in the Go grammar.
func formatUnequalValues(expected, actual interface{}) (e string, a string) {
if reflect.TypeOf(expected) != reflect.TypeOf(actual) {
return fmt.Sprintf("%T(%#v)", expected, expected),
fmt.Sprintf("%T(%#v)", actual, actual)
return fmt.Sprintf("%T(%s)", expected, truncatingFormat(expected)),
fmt.Sprintf("%T(%s)", actual, truncatingFormat(actual))
}
switch expected.(type) {
case time.Duration:
return fmt.Sprintf("%v", expected), fmt.Sprintf("%v", actual)
}
return truncatingFormat(expected), truncatingFormat(actual)
}
return fmt.Sprintf("%#v", expected),
fmt.Sprintf("%#v", actual)
// truncatingFormat formats the data and truncates it if it's too long.
//
// This helps keep formatted error messages lines from exceeding the
// bufio.MaxScanTokenSize max line length that the go testing framework imposes.
func truncatingFormat(data interface{}) string {
value := fmt.Sprintf("%#v", data)
max := bufio.MaxScanTokenSize - 100 // Give us some space the type info too if needed.
if len(value) > max {
value = value[0:max] + "<... truncated>"
}
return value
}
// EqualValues asserts that two objects are equal or convertable to the same types
@@ -442,12 +496,12 @@ func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
//
// assert.NotNil(t, err)
func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !isNil(object) {
return true
}
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, "Expected value not to be nil.", msgAndArgs...)
}
@@ -488,12 +542,12 @@ func isNil(object interface{}) bool {
//
// assert.Nil(t, err)
func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if isNil(object) {
return true
}
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, fmt.Sprintf("Expected nil, but got: %#v", object), msgAndArgs...)
}
@@ -530,12 +584,11 @@ func isEmpty(object interface{}) bool {
//
// assert.Empty(t, obj)
func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
pass := isEmpty(object)
if !pass {
if h, ok := t.(tHelper); ok {
h.Helper()
}
Fail(t, fmt.Sprintf("Should be empty, but was %v", object), msgAndArgs...)
}
@@ -550,12 +603,11 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
// assert.Equal(t, "two", obj[1])
// }
func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
pass := !isEmpty(object)
if !pass {
if h, ok := t.(tHelper); ok {
h.Helper()
}
Fail(t, fmt.Sprintf("Should NOT be empty, but was %v", object), msgAndArgs...)
}
@@ -598,16 +650,10 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{})
//
// assert.True(t, myBool)
func True(t TestingT, value bool, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if h, ok := t.(interface {
Helper()
}); ok {
h.Helper()
}
if value != true {
if !value {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, "Should be true", msgAndArgs...)
}
@@ -619,11 +665,10 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) bool {
//
// assert.False(t, myBool)
func False(t TestingT, value bool, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if value != false {
if value {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, "Should be false", msgAndArgs...)
}
@@ -654,14 +699,33 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{
}
// NotEqualValues asserts that two objects are not equal even when converted to the same type
//
// assert.NotEqualValues(t, obj1, obj2)
func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if ObjectsAreEqualValues(expected, actual) {
return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...)
}
return true
}
// containsElement try loop over the list check if the list includes the element.
// return (false, false) if impossible.
// return (true, false) if element was not found.
// return (true, true) if element was found.
func includeElement(list interface{}, element interface{}) (ok, found bool) {
func containsElement(list interface{}, element interface{}) (ok, found bool) {
listValue := reflect.ValueOf(list)
listKind := reflect.TypeOf(list).Kind()
listType := reflect.TypeOf(list)
if listType == nil {
return false, false
}
listKind := listType.Kind()
defer func() {
if e := recover(); e != nil {
ok = false
@@ -704,12 +768,12 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
}
if !found {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", s, contains), msgAndArgs...)
return Fail(t, fmt.Sprintf("%#v does not contain %#v", s, contains), msgAndArgs...)
}
return true
@@ -727,7 +791,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
h.Helper()
}
ok, found := includeElement(s, contains)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
}
@@ -771,7 +835,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -792,7 +856,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
h.Helper()
}
if subset == nil {
return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...)
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
@@ -815,7 +879,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
ok, found := includeElement(list, element)
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
}
@@ -840,27 +904,39 @@ func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface
return true
}
aKind := reflect.TypeOf(listA).Kind()
bKind := reflect.TypeOf(listB).Kind()
if aKind != reflect.Array && aKind != reflect.Slice {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", listA, aKind), msgAndArgs...)
if !isList(t, listA, msgAndArgs...) || !isList(t, listB, msgAndArgs...) {
return false
}
if bKind != reflect.Array && bKind != reflect.Slice {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", listB, bKind), msgAndArgs...)
extraA, extraB := diffLists(listA, listB)
if len(extraA) == 0 && len(extraB) == 0 {
return true
}
return Fail(t, formatListDiff(listA, listB, extraA, extraB), msgAndArgs...)
}
// isList checks that the provided value is array or slice.
func isList(t TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) {
kind := reflect.TypeOf(list).Kind()
if kind != reflect.Array && kind != reflect.Slice {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s, expecting array or slice", list, kind),
msgAndArgs...)
}
return true
}
// diffLists diffs two arrays/slices and returns slices of elements that are only in A and only in B.
// If some element is present multiple times, each instance is counted separately (e.g. if something is 2x in A and
// 5x in B, it will be 0x in extraA and 3x in extraB). The order of items in both lists is ignored.
func diffLists(listA, listB interface{}) (extraA, extraB []interface{}) {
aValue := reflect.ValueOf(listA)
bValue := reflect.ValueOf(listB)
aLen := aValue.Len()
bLen := bValue.Len()
if aLen != bLen {
return Fail(t, fmt.Sprintf("lengths don't match: %d != %d", aLen, bLen), msgAndArgs...)
}
// Mark indexes in bValue that we already used
visited := make([]bool, bLen)
for i := 0; i < aLen; i++ {
@@ -877,11 +953,38 @@ func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface
}
}
if !found {
return Fail(t, fmt.Sprintf("element %s appears more times in %s than in %s", element, aValue, bValue), msgAndArgs...)
extraA = append(extraA, element)
}
}
return true
for j := 0; j < bLen; j++ {
if visited[j] {
continue
}
extraB = append(extraB, bValue.Index(j).Interface())
}
return
}
func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) string {
var msg bytes.Buffer
msg.WriteString("elements differ")
if len(extraA) > 0 {
msg.WriteString("\n\nextra elements in list A:\n")
msg.WriteString(spewConfig.Sdump(extraA))
}
if len(extraB) > 0 {
msg.WriteString("\n\nextra elements in list B:\n")
msg.WriteString(spewConfig.Sdump(extraB))
}
msg.WriteString("\n\nlistA:\n")
msg.WriteString(spewConfig.Sdump(listA))
msg.WriteString("\n\nlistB:\n")
msg.WriteString(spewConfig.Sdump(listB))
return msg.String()
}
// Condition uses a Comparison to assert a complex condition.
@@ -901,25 +1004,21 @@ func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
type PanicTestFunc func()
// didPanic returns true if the function passed to it panics. Otherwise, it returns false.
func didPanic(f PanicTestFunc) (bool, interface{}) {
didPanic := false
var message interface{}
func() {
defer func() {
if message = recover(); message != nil {
didPanic = true
}
}()
// call the target function
f()
func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string) {
didPanic = true
defer func() {
message = recover()
if didPanic {
stack = string(debug.Stack())
}
}()
return didPanic, message
// call the target function
f()
didPanic = false
return
}
// Panics asserts that the code inside the specified PanicTestFunc panics.
@@ -930,7 +1029,7 @@ func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
h.Helper()
}
if funcDidPanic, panicValue := didPanic(f); !funcDidPanic {
if funcDidPanic, panicValue, _ := didPanic(f); !funcDidPanic {
return Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...)
}
@@ -946,12 +1045,34 @@ func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndAr
h.Helper()
}
funcDidPanic, panicValue := didPanic(f)
funcDidPanic, panicValue, panickedStack := didPanic(f)
if !funcDidPanic {
return Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...)
}
if panicValue != expected {
return Fail(t, fmt.Sprintf("func %#v should panic with value:\t%#v\n\tPanic value:\t%#v", f, expected, panicValue), msgAndArgs...)
return Fail(t, fmt.Sprintf("func %#v should panic with value:\t%#v\n\tPanic value:\t%#v\n\tPanic stack:\t%s", f, expected, panicValue, panickedStack), msgAndArgs...)
}
return true
}
// PanicsWithError asserts that the code inside the specified PanicTestFunc
// panics, and that the recovered panic value is an error that satisfies the
// EqualError comparison.
//
// assert.PanicsWithError(t, "crazy error", func(){ GoCrazy() })
func PanicsWithError(t TestingT, errString string, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
funcDidPanic, panicValue, panickedStack := didPanic(f)
if !funcDidPanic {
return Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...)
}
panicErr, ok := panicValue.(error)
if !ok || panicErr.Error() != errString {
return Fail(t, fmt.Sprintf("func %#v should panic with error message:\t%#v\n\tPanic value:\t%#v\n\tPanic stack:\t%s", f, errString, panicValue, panickedStack), msgAndArgs...)
}
return true
@@ -965,8 +1086,8 @@ func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
h.Helper()
}
if funcDidPanic, panicValue := didPanic(f); funcDidPanic {
return Fail(t, fmt.Sprintf("func %#v should not panic\n\tPanic value:\t%v", f, panicValue), msgAndArgs...)
if funcDidPanic, panicValue, panickedStack := didPanic(f); funcDidPanic {
return Fail(t, fmt.Sprintf("func %#v should not panic\n\tPanic value:\t%v\n\tPanic stack:\t%s", f, panicValue, panickedStack), msgAndArgs...)
}
return true
@@ -993,6 +1114,8 @@ func toFloat(x interface{}) (float64, bool) {
xok := true
switch xn := x.(type) {
case uint:
xf = float64(xn)
case uint8:
xf = float64(xn)
case uint16:
@@ -1014,7 +1137,7 @@ func toFloat(x interface{}) (float64, bool) {
case float32:
xf = float64(xn)
case float64:
xf = float64(xn)
xf = xn
case time.Duration:
xf = float64(xn)
default:
@@ -1026,7 +1149,7 @@ func toFloat(x interface{}) (float64, bool) {
// InDelta asserts that the two numerals are within delta of each other.
//
// assert.InDelta(t, math.Pi, (22 / 7.0), 0.01)
// assert.InDelta(t, math.Pi, 22/7.0, 0.01)
func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1036,11 +1159,15 @@ func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs
bf, bok := toFloat(actual)
if !aok || !bok {
return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...)
return Fail(t, "Parameters must be numerical", msgAndArgs...)
}
if math.IsNaN(af) && math.IsNaN(bf) {
return true
}
if math.IsNaN(af) {
return Fail(t, fmt.Sprintf("Expected must not be NaN"), msgAndArgs...)
return Fail(t, "Expected must not be NaN", msgAndArgs...)
}
if math.IsNaN(bf) {
@@ -1063,7 +1190,7 @@ func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAn
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1125,15 +1252,21 @@ func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, m
func calcRelativeError(expected, actual interface{}) (float64, error) {
af, aok := toFloat(expected)
if !aok {
return 0, fmt.Errorf("expected value %q cannot be converted to float", expected)
bf, bok := toFloat(actual)
if !aok || !bok {
return 0, fmt.Errorf("Parameters must be numerical")
}
if math.IsNaN(af) && math.IsNaN(bf) {
return 0, nil
}
if math.IsNaN(af) {
return 0, errors.New("expected value must not be NaN")
}
if af == 0 {
return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error")
}
bf, bok := toFloat(actual)
if !bok {
return 0, fmt.Errorf("actual value %q cannot be converted to float", actual)
if math.IsNaN(bf) {
return 0, errors.New("actual value must not be NaN")
}
return math.Abs(af-bf) / math.Abs(af), nil
@@ -1144,6 +1277,9 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd
if h, ok := t.(tHelper); ok {
h.Helper()
}
if math.IsNaN(epsilon) {
return Fail(t, "epsilon must not be NaN")
}
actualEpsilon, err := calcRelativeError(expected, actual)
if err != nil {
return Fail(t, err.Error(), msgAndArgs...)
@@ -1164,7 +1300,7 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
return Fail(t, "Parameters must be slice", msgAndArgs...)
}
actualSlice := reflect.ValueOf(actual)
@@ -1191,10 +1327,10 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
// assert.Equal(t, expectedObj, actualObj)
// }
func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if err != nil {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...)
}
@@ -1208,11 +1344,10 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool {
// assert.Equal(t, expectedError, err)
// }
func Error(t TestingT, err error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if err == nil {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, "An error is expected but got nil.", msgAndArgs...)
}
@@ -1242,6 +1377,27 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
return true
}
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !Error(t, theError, msgAndArgs...) {
return false
}
actual := theError.Error()
if !strings.Contains(actual, contains) {
return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...)
}
return true
}
// matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool {
@@ -1314,7 +1470,8 @@ func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
return true
}
// FileExists checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
// FileExists checks whether a file exists in the given path. It also fails if
// the path points to a directory or there is an error when trying to check the file.
func FileExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1332,7 +1489,24 @@ func FileExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
return true
}
// DirExists checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
// NoFileExists checks whether a file does not exist in a given path. It fails
// if the path points to an existing _file_ only.
func NoFileExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
info, err := os.Lstat(path)
if err != nil {
return true
}
if info.IsDir() {
return true
}
return Fail(t, fmt.Sprintf("file %q exists", path), msgAndArgs...)
}
// DirExists checks whether a directory exists in the given path. It also fails
// if the path is a file rather a directory or there is an error checking whether it exists.
func DirExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1350,6 +1524,25 @@ func DirExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
return true
}
// NoDirExists checks whether a directory does not exist in the given path.
// It fails if the path points to an existing _directory_ only.
func NoDirExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
info, err := os.Lstat(path)
if err != nil {
if os.IsNotExist(err) {
return true
}
return true
}
if !info.IsDir() {
return true
}
return Fail(t, fmt.Sprintf("directory %q exists", path), msgAndArgs...)
}
// JSONEq asserts that two JSON strings are equivalent.
//
// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`)
@@ -1418,12 +1611,17 @@ func diff(expected interface{}, actual interface{}) string {
}
var e, a string
if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
} else {
switch et {
case reflect.TypeOf(""):
e = reflect.ValueOf(expected).String()
a = reflect.ValueOf(actual).String()
case reflect.TypeOf(time.Time{}):
e = spewConfigStringerEnabled.Sdump(expected)
a = spewConfigStringerEnabled.Sdump(actual)
default:
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
}
diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
@@ -1439,15 +1637,6 @@ func diff(expected interface{}, actual interface{}) string {
return "\n\nDiff:\n" + diff
}
// validateEqualArgs checks whether provided arguments can be safely used in the
// Equal/NotEqual functions.
func validateEqualArgs(expected, actual interface{}) error {
if isFunction(expected) || isFunction(actual) {
return errors.New("cannot take func type as argument")
}
return nil
}
func isFunction(arg interface{}) bool {
if arg == nil {
return false
@@ -1460,6 +1649,16 @@ var spewConfig = spew.ConfigState{
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
DisableMethods: true,
MaxDepth: 10,
}
var spewConfigStringerEnabled = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
MaxDepth: 10,
}
type tHelper interface {
@@ -1475,24 +1674,137 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
h.Helper()
}
ch := make(chan bool, 1)
timer := time.NewTimer(waitFor)
ticker := time.NewTicker(tick)
checkPassed := make(chan bool)
defer timer.Stop()
ticker := time.NewTicker(tick)
defer ticker.Stop()
defer close(checkPassed)
for {
for tick := ticker.C; ; {
select {
case <-timer.C:
return Fail(t, "Condition never satisfied", msgAndArgs...)
case result := <-checkPassed:
if result {
case <-tick:
tick = nil
go func() { ch <- condition() }()
case v := <-ch:
if v {
return true
}
case <-ticker.C:
go func() {
checkPassed <- condition()
}()
tick = ticker.C
}
}
}
// Never asserts that the given condition doesn't satisfy in waitFor time,
// periodically checking the target function each tick.
//
// assert.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond)
func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
ch := make(chan bool, 1)
timer := time.NewTimer(waitFor)
defer timer.Stop()
ticker := time.NewTicker(tick)
defer ticker.Stop()
for tick := ticker.C; ; {
select {
case <-timer.C:
return true
case <-tick:
tick = nil
go func() { ch <- condition() }()
case v := <-ch:
if v {
return Fail(t, "Condition satisfied", msgAndArgs...)
}
tick = ticker.C
}
}
}
// ErrorIs asserts that at least one of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if errors.Is(err, target) {
return true
}
var expectedText string
if target != nil {
expectedText = target.Error()
}
chain := buildErrorChainString(err)
return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+
"expected: %q\n"+
"in chain: %s", expectedText, chain,
), msgAndArgs...)
}
// NotErrorIs asserts that at none of the errors in err's chain matches target.
// This is a wrapper for errors.Is.
func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !errors.Is(err, target) {
return true
}
var expectedText string
if target != nil {
expectedText = target.Error()
}
chain := buildErrorChainString(err)
return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %q\n"+
"in chain: %s", expectedText, chain,
), msgAndArgs...)
}
// ErrorAs asserts that at least one of the errors in err's chain matches target, and if so, sets target to that error value.
// This is a wrapper for errors.As.
func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if errors.As(err, target) {
return true
}
chain := buildErrorChainString(err)
return Fail(t, fmt.Sprintf("Should be in error chain:\n"+
"expected: %q\n"+
"in chain: %s", target, chain,
), msgAndArgs...)
}
func buildErrorChainString(err error) string {
if err == nil {
return ""
}
e := errors.Unwrap(err)
chain := fmt.Sprintf("%q", err.Error())
for e != nil {
chain += fmt.Sprintf("\n\t%q", e.Error())
e = errors.Unwrap(e)
}
return chain
}

View File

@@ -13,4 +13,4 @@ func New(t TestingT) *Assertions {
}
}
//go:generate go run ../_codegen/main.go -output-package=assert -template=assertion_forward.go.tmpl -include-format-funcs
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_forward.go.tmpl -include-format-funcs"

View File

@@ -33,7 +33,6 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, value
code, err := httpCode(handler, method, url, values)
if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
return false
}
isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent
@@ -56,7 +55,6 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, valu
code, err := httpCode(handler, method, url, values)
if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
return false
}
isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect
@@ -79,7 +77,6 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values
code, err := httpCode(handler, method, url, values)
if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
return false
}
isErrorCode := code >= http.StatusBadRequest
@@ -90,6 +87,28 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values
return isErrorCode
}
// HTTPStatusCode asserts that a specified handler returns a specified status code.
//
// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501)
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
code, err := httpCode(handler, method, url, values)
if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
}
successful := code == statuscode
if !successful {
Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code))
}
return successful
}
// HTTPBody is a helper that returns HTTP body of the response. It returns
// empty string if building a new request fails.
func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string {

View File

@@ -13,4 +13,4 @@ func New(t TestingT) *Assertions {
}
}
//go:generate go run ../_codegen/main.go -output-package=require -template=require_forward.go.tmpl -include-format-funcs
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=require -template=require_forward.go.tmpl -include-format-funcs"

Some files were not shown because too many files have changed in this diff Show More