mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-01 14:32:15 +08:00
Compare commits
132 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
3a15cc3add | ||
![]() |
765f6e7c2e | ||
![]() |
e3bdfc1f8e | ||
![]() |
60cd972b7f | ||
![]() |
ee081d0abe | ||
![]() |
d97b4bb81d | ||
![]() |
94aeacf0cb | ||
![]() |
5cb8a081a1 | ||
![]() |
fc00112e47 | ||
![]() |
bbc22fae5b | ||
![]() |
82cb75913d | ||
![]() |
45c4a64b87 | ||
![]() |
bae2579497 | ||
![]() |
1bc01271cb | ||
![]() |
352a71f50c | ||
![]() |
6f9f62e38f | ||
![]() |
0f67d9e8ff | ||
![]() |
f2dd5b63ae | ||
![]() |
54e2d044a2 | ||
![]() |
6298a87298 | ||
![]() |
b6fd25bba4 | ||
![]() |
eef3592576 | ||
![]() |
5d343c12e1 | ||
![]() |
70f52c8a3b | ||
![]() |
429b72265a | ||
![]() |
f60d2dcfca | ||
![]() |
6674cd64eb | ||
![]() |
f218cde69c | ||
![]() |
9ea687eb94 | ||
![]() |
949e4e2e91 | ||
![]() |
515e0269de | ||
![]() |
ae6073c79c | ||
![]() |
b072a08f0b | ||
![]() |
01d8a450d2 | ||
![]() |
da2fd41f79 | ||
![]() |
56e8039093 | ||
![]() |
7b9bc844c1 | ||
![]() |
8acb182820 | ||
![]() |
5726880095 | ||
![]() |
ee459e1b3d | ||
![]() |
6aec3a8bbf | ||
![]() |
74699f0a87 | ||
![]() |
70def39ff9 | ||
![]() |
8e7098a32d | ||
![]() |
e4f02919fd | ||
![]() |
99c96c844e | ||
![]() |
18629aea6d | ||
![]() |
a0060429d1 | ||
![]() |
d946a9ae16 | ||
![]() |
51a2eb5f48 | ||
![]() |
0d4b0a89d8 | ||
![]() |
0e7ccfe3fb | ||
![]() |
5d7230630d | ||
![]() |
6a3cbd6093 | ||
![]() |
7b4e79707b | ||
![]() |
c6643592f6 | ||
![]() |
5de12d0460 | ||
![]() |
0a7205e110 | ||
![]() |
8133dd8299 | ||
![]() |
fdbfff57dc | ||
![]() |
f5fc5e8c44 | ||
![]() |
9f44712b80 | ||
![]() |
1f86168d9d | ||
![]() |
ab25083ed2 | ||
![]() |
9b0aa4d559 | ||
![]() |
03814944a9 | ||
![]() |
3286d5a484 | ||
![]() |
7e970d3c7a | ||
![]() |
d6a92cc5bd | ||
![]() |
325d44d478 | ||
![]() |
0a5f6d3a9d | ||
![]() |
17253ad8bd | ||
![]() |
9f1c387091 | ||
![]() |
9c6f602630 | ||
![]() |
b0dcaabdde | ||
![]() |
460f0ef681 | ||
![]() |
6e16765f60 | ||
![]() |
2b361df19e | ||
![]() |
c8c0a5a094 | ||
![]() |
4a833dd081 | ||
![]() |
81198d9845 | ||
![]() |
6c12d8a71a | ||
![]() |
19b598b672 | ||
![]() |
b6529f05d3 | ||
![]() |
7f76445cc8 | ||
![]() |
b1c01792cd | ||
![]() |
eda03d4338 | ||
![]() |
18070f1f57 | ||
![]() |
7f10c28a37 | ||
![]() |
122531bb27 | ||
![]() |
e6dbcae428 | ||
![]() |
98875de568 | ||
![]() |
c9fd9451af | ||
![]() |
6550b8d680 | ||
![]() |
a60c96c889 | ||
![]() |
86e0a5827e | ||
![]() |
06c399b606 | ||
![]() |
ed117f67a1 | ||
![]() |
880a3299e1 | ||
![]() |
1c408d05be | ||
![]() |
fce495f83e | ||
![]() |
471ca00a64 | ||
![]() |
a2c0749640 | ||
![]() |
37293aeecf | ||
![]() |
7a2d4db6a4 | ||
![]() |
03d2a8bc82 | ||
![]() |
4b51e5c7d1 | ||
![]() |
d15ad682bf | ||
![]() |
130ffcbb53 | ||
![]() |
33cf2f991b | ||
![]() |
a360ea6a6c | ||
![]() |
ae3aa0d3fa | ||
![]() |
811ae0e1be | ||
![]() |
51d6825430 | ||
![]() |
514288c53e | ||
![]() |
957fc0a049 | ||
![]() |
03f94f948a | ||
![]() |
1bc752a2b8 | ||
![]() |
b9db59ba12 | ||
![]() |
c0ef58c363 | ||
![]() |
994adea3b4 | ||
![]() |
fc61cc9be5 | ||
![]() |
22d7338878 | ||
![]() |
3f28515706 | ||
![]() |
7d73ce9caf | ||
![]() |
0758bc961c | ||
![]() |
8472b9ae8a | ||
![]() |
530a018e80 | ||
![]() |
0b594afb4e | ||
![]() |
9d0ea957bb | ||
![]() |
8067785ac4 | ||
![]() |
6ffc8a8388 |
43
.github/workflows/build.yml
vendored
Normal file
43
.github/workflows/build.yml
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
name: build
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.17
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -v -race ./...
|
||||
|
||||
|
||||
coverage:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
if: success()
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.17.x
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- name: Calc coverage
|
||||
run: |
|
||||
go test -v -covermode=count -coverprofile=coverage.out ./...
|
||||
- name: Convert coverage.out to coverage.lcov
|
||||
uses: jandelgado/gcov2lcov-action@v1.0.6
|
||||
- name: Coveralls
|
||||
uses: coverallsapp/github-action@v1.1.2
|
||||
with:
|
||||
github-token: ${{ secrets.github_token }}
|
||||
path-to-lcov: coverage.lcov
|
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
cmd/mqtt
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
server/persistence/bolt/testbolt.db
|
||||
|
18
.travis.yml
18
.travis.yml
@@ -1,18 +0,0 @@
|
||||
dist: xenial
|
||||
|
||||
language: go
|
||||
|
||||
env:
|
||||
- GO111MODULE=on
|
||||
|
||||
go:
|
||||
- 1.13.x
|
||||
|
||||
git:
|
||||
depth: 1
|
||||
|
||||
script:
|
||||
- go test -v -race ./... -coverprofile=coverage.txt -covermode=atomic
|
||||
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
31
Dockerfile
Normal file
31
Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
||||
FROM golang:1.18.0-alpine3.15 AS builder
|
||||
|
||||
RUN apk update
|
||||
RUN apk add git
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod ./
|
||||
COPY go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . ./
|
||||
|
||||
RUN go build -o /app/mochi ./cmd
|
||||
|
||||
|
||||
FROM alpine
|
||||
|
||||
WORKDIR /
|
||||
COPY --from=builder /app/mochi .
|
||||
|
||||
# tcp
|
||||
EXPOSE 1883
|
||||
|
||||
# websockets
|
||||
EXPOSE 1882
|
||||
|
||||
# dashboard
|
||||
EXPOSE 8080
|
||||
|
||||
ENTRYPOINT [ "/mochi" ]
|
85
README.md
85
README.md
@@ -1,10 +1,11 @@
|
||||
|
||||
<p align="center">
|
||||
|
||||
[](https://travis-ci.com/mochi-co/mqtt)
|
||||
|
||||

|
||||
[](https://coveralls.io/github/mochi-co/mqtt?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-co/mqtt)
|
||||
[](https://pkg.go.dev/github.com/mochi-co/mqtt)
|
||||
[](https://github.com/mochi-co/mqtt/issues)
|
||||
[](https://codecov.io/gh/mochi-co/mqtt)
|
||||
[](https://pkg.go.dev/github.com/mochi-co/mqtt)
|
||||
|
||||
</p>
|
||||
|
||||
@@ -13,6 +14,10 @@
|
||||
|
||||
Mochi MQTT is an embeddable high-performance MQTT broker server written in Go, and compliant with the MQTT v3.0 and v3.1.1 specification for the development of IoT and smarthome projects. The server can be used either as a standalone binary or embedded as a library in your own projects. Mochi MQTT message throughput is comparable with everyone's favourites such as Mosquitto, Mosca, and VerneMQ.
|
||||
|
||||
> #### 📦 💬 See Github Discussions for discussions about releases
|
||||
> Ongoing discussion about current and future releases can be found at https://github.com/mochi-co/mqtt/discussions
|
||||
|
||||
|
||||
#### What is MQTT?
|
||||
MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks. [Learn more](https://mqtt.org/faq)
|
||||
|
||||
@@ -23,11 +28,13 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
|
||||
- Ring Buffer packet codec.
|
||||
- TCP, Websocket, (including SSL/TLS) and Dashboard listeners.
|
||||
- Interfaces for Client Authentication and Topic access control.
|
||||
- Bolt-backed persistence and storage interfaces.
|
||||
- Bolt persistence and storage interfaces (see examples folder).
|
||||
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
|
||||
- Basic Event Hooks (currently `onMessage`)
|
||||
- Basic Event Hooks (`OnMessage`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
|
||||
- ARM32 Compatible.
|
||||
|
||||
#### Roadmap
|
||||
- Please open an issue to request new features or event hooks.
|
||||
- MQTT v5 compatibility?
|
||||
|
||||
#### Using the Broker
|
||||
@@ -47,7 +54,7 @@ import (
|
||||
|
||||
func main() {
|
||||
// Create the new MQTT Server.
|
||||
server := mqtt.New()
|
||||
server := mqtt.NewServer(nil)
|
||||
|
||||
// Create a TCP listener on a standard port.
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
@@ -105,11 +112,42 @@ err := server.AddListener(tcp, &listeners.Config{
|
||||
#### Event Hooks
|
||||
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
|
||||
|
||||
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
|
||||
|
||||
##### OnConnect
|
||||
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
|
||||
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/server/events"
|
||||
|
||||
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
|
||||
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnDisconnect
|
||||
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
|
||||
|
||||
```go
|
||||
server.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnMessage
|
||||
`server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns.
|
||||
`server.Events.OnMessage` is called when a Publish packet (message) is received. The method receives the published message and information about the client who published it.
|
||||
|
||||
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
|
||||
|
||||
|
||||
##### OnProcessMessage
|
||||
`server.Events.OnProcessMessage` is called before a publish packet (message) is processed. Specifically, the method callback is triggered after topic and ACL validation has occurred, but before the headers and payload are processed. You can use this if you want to programmatically change the data of the packet, such as setting it to retain, or altering the QoS flag.
|
||||
|
||||
If an error is returned, the packet will not be modified. and the existing packet will be used. If this is an unwanted outcome, the `mqtt.ErrRejectPacket` error can be returned from the callback, and the packet will be dropped/ignored, any further processing is abandoned.
|
||||
|
||||
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
|
||||
|
||||
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/server/events"
|
||||
|
||||
@@ -124,7 +162,34 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
|
||||
}
|
||||
```
|
||||
|
||||
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
|
||||
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
|
||||
|
||||
##### OnError
|
||||
`server.Events.OnError` is called when an error is encountered on the server, particularly within the use of a client connection status.
|
||||
|
||||
##### OnStorage
|
||||
`server.Events.OnStorage` is like `onError`, but receives the output of persistent storage methods.
|
||||
|
||||
|
||||
#### Server Options
|
||||
A few options can be passed to the `mqtt.NewServer(opts *Options)` function in order to override the default broker configuration. Currently these options are:
|
||||
|
||||
|
||||
- BufferSize (default 1024 * 256 bytes) - The default value is sufficient for most messaging sizes, but if you are sending many kilobytes of data (such as images), you should increase this to a value of (n*s) where is the typical size of your message and n is the number of messages you may have backlogged for a client at any given time.
|
||||
- BufferBlockSize (default 1024 * 8) - The minimum size in which R/W data will be allocated. If you are expecting only tiny or large payloads, you can alter this accordingly.
|
||||
|
||||
Any options which is not set or is `0` will use default values.
|
||||
|
||||
```go
|
||||
opts := &mqtt.Options{
|
||||
BufferSize: 512 * 1024,
|
||||
BufferBlockSize: 16 * 1024,
|
||||
}
|
||||
|
||||
s := mqtt.NewServer(opts)
|
||||
```
|
||||
|
||||
> See `examples/tcp/main.go` for an example implementation.
|
||||
|
||||
#### Direct Publishing
|
||||
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.
|
||||
@@ -154,7 +219,7 @@ if err != nil {
|
||||
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the test with `python3 client_test.py` from the _interoperability_ folder.
|
||||
|
||||
|
||||
#### Performance (messages/second)
|
||||
#### Performance at v1.0.0
|
||||
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a 13-inch, Early 2015 Macbook Pro (2.7 GHz Intel Core i5). Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better. SEND = Publish throughput, RECV = Subscribe throughput.
|
||||
|
||||
> As usual, any performance benchmarks should be taken with a pinch of salt, but are shown to demonstrate typical throughput compared to the other leading MQTT brokers.
|
||||
|
@@ -33,7 +33,7 @@ func main() {
|
||||
fmt.Println(aurora.Cyan("Websocket"), *wsAddr)
|
||||
fmt.Println(aurora.Cyan("$SYS Dashboard"), *infoAddr)
|
||||
|
||||
server := mqtt.New()
|
||||
server := mqtt.NewServer(nil)
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr)
|
||||
err := server.AddListener(tcp, nil)
|
||||
if err != nil {
|
||||
|
@@ -24,7 +24,7 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
|
||||
|
||||
server := mqtt.New()
|
||||
server := mqtt.NewServer(nil)
|
||||
|
||||
stats := listeners.NewHTTPStats("stats", ":8080")
|
||||
err := server.AddListener(stats, nil)
|
||||
|
@@ -27,7 +27,7 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
|
||||
|
||||
server := mqtt.New()
|
||||
server := mqtt.NewServer(nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
err := server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
@@ -44,6 +44,16 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Add OnConnect Event Hook
|
||||
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
|
||||
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
|
||||
}
|
||||
|
||||
// Add OnDisconnect Event Hook
|
||||
server.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
|
||||
// Add OnMessage Event Hook
|
||||
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
|
||||
pkx = pk
|
||||
@@ -54,6 +64,12 @@ func main() {
|
||||
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
|
||||
}
|
||||
|
||||
// Example of using AllowClients to selectively deliver/drop messages.
|
||||
// Only a client with the id of `allowed-client` will received messages on the topic.
|
||||
if pkx.TopicName == "a/b/restricted" {
|
||||
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
|
||||
}
|
||||
|
||||
return pkx, nil
|
||||
}
|
||||
|
||||
|
@@ -46,7 +46,7 @@ func main() {
|
||||
// Auth is an example auth provider for the server.
|
||||
type Auth struct{}
|
||||
|
||||
// Auth returns true if a username and password are acceptable.
|
||||
// Authenticate returns true if a username and password are acceptable.
|
||||
// Auth always returns true.
|
||||
func (a *Auth) Authenticate(user, password []byte) bool {
|
||||
return true
|
||||
@@ -55,8 +55,5 @@ func (a *Auth) Authenticate(user, password []byte) bool {
|
||||
// ACL returns true if a user has access permissions to read or write on a topic.
|
||||
// ACL is used to deny access to a specific topic to satisfy Test.test_subscribe_failure.
|
||||
func (a *Auth) ACL(user []byte, topic string, write bool) bool {
|
||||
if topic == "test/nosubscribe" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return topic != "test/nosubscribe"
|
||||
}
|
||||
|
@@ -28,7 +28,7 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("Persistence"))
|
||||
|
||||
server := mqtt.New()
|
||||
server := mqtt.NewServer(nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
err := server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
|
@@ -25,7 +25,13 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
|
||||
|
||||
server := mqtt.New()
|
||||
// An example of configuring various server options...
|
||||
options := &mqtt.Options{
|
||||
BufferSize: 0, // Use default values
|
||||
BufferBlockSize: 0, // Use default values
|
||||
}
|
||||
|
||||
server := mqtt.NewServer(options)
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
err := server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
|
@@ -57,7 +57,7 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
|
||||
|
||||
server := mqtt.New()
|
||||
server := mqtt.NewServer(nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
err := server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
|
@@ -24,7 +24,7 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TCP"))
|
||||
|
||||
server := mqtt.New()
|
||||
server := mqtt.NewServer(nil)
|
||||
ws := listeners.NewWebsocket("ws1", ":1882")
|
||||
err := server.AddListener(ws, nil)
|
||||
if err != nil {
|
||||
|
12
go.mod
12
go.mod
@@ -1,16 +1,16 @@
|
||||
module github.com/mochi-co/mqtt
|
||||
|
||||
go 1.17
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/asdine/storm v2.1.2+incompatible
|
||||
github.com/asdine/storm/v3 v3.2.1
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/jinzhu/copier v0.3.4
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.3.5
|
||||
github.com/logrusorgru/aurora v2.0.3+incompatible
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/stretchr/testify v1.7.0
|
||||
go.etcd.io/bbolt v1.3.6
|
||||
github.com/rs/xid v1.4.0
|
||||
github.com/stretchr/testify v1.7.1
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
)
|
||||
|
||||
require (
|
||||
|
20
go.sum
20
go.sum
@@ -14,10 +14,10 @@ github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jinzhu/copier v0.3.4 h1:mfU6jI9PtCeUjkjQ322dlff9ELjGDu975C2p/nrubVI=
|
||||
github.com/jinzhu/copier v0.3.4/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
|
||||
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
@@ -27,17 +27,17 @@ github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczG
|
||||
github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
|
||||
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
|
||||
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU=
|
||||
go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4=
|
||||
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
|
||||
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20191105084925-a882066a44e0 h1:QPlSTtPE2k6PZPasQUbzuK3p9JbS+vMXYVto8g/yrsg=
|
||||
|
@@ -1,29 +1,49 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"github.com/mochi-co/mqtt/server/internal/clients"
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
)
|
||||
|
||||
// Events provides callback handlers for different event hooks.
|
||||
type Events struct {
|
||||
OnMessage // published message receieved.
|
||||
OnProcessMessage // published message receieved before evaluation.
|
||||
OnMessage // published message receieved.
|
||||
OnError // server error.
|
||||
OnConnect // client connected.
|
||||
OnDisconnect // client disconnected.
|
||||
}
|
||||
|
||||
// Packets is an alias for packets.Packet.
|
||||
type Packet packets.Packet
|
||||
|
||||
// Client contains limited information about a connected client.
|
||||
type Client struct {
|
||||
ID string
|
||||
Remote string
|
||||
Listener string
|
||||
}
|
||||
|
||||
// FromClient returns an event client from a client.
|
||||
func FromClient(cl clients.Client) Client {
|
||||
return Client{
|
||||
ID: cl.ID,
|
||||
Listener: cl.Listener,
|
||||
}
|
||||
// Clientlike is an interface for Clients and client-like objects that
|
||||
// are able to describe their client/listener IDs and remote address.
|
||||
type Clientlike interface {
|
||||
Info() Client
|
||||
}
|
||||
|
||||
// OnProcessMessage is called when a publish message is received, allowing modification
|
||||
// of the packet data after ACL checking has occurred but before any data is evaluated
|
||||
// for processing - e.g. for changing the Retain flag. Note, this hook is ONLY called
|
||||
// by connected client publishers, it is not triggered when using the direct
|
||||
// s.Publish method. The function receives the sent message and the
|
||||
// data of the client who published it, and allows the packet to be modified
|
||||
// before it is dispatched to subscribers. If no modification is required, return
|
||||
// the original packet data. If an error occurs, the original packet will
|
||||
// be dispatched as if the event hook had not been triggered.
|
||||
// This function will block message dispatching until it returns. To minimise this,
|
||||
// have the function open a new goroutine on the embedding side.
|
||||
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
|
||||
// processing of the packet.
|
||||
type OnProcessMessage func(Client, Packet) (Packet, error)
|
||||
|
||||
// OnMessage function is called when a publish message is received. Note,
|
||||
// this hook is ONLY called by connected client publishers, it is not triggered when
|
||||
// using the direct s.Publish method. The function receives the sent message and the
|
||||
@@ -34,3 +54,15 @@ func FromClient(cl clients.Client) Client {
|
||||
// This function will block message dispatching until it returns. To minimise this,
|
||||
// have the function open a new goroutine on the embedding side.
|
||||
type OnMessage func(Client, Packet) (Packet, error)
|
||||
|
||||
// OnConnect is called when a client successfully connects to the broker.
|
||||
type OnConnect func(Client, Packet)
|
||||
|
||||
// OnDisconnect is called when a client disconnects to the broker. An error value
|
||||
// is passed to the function if the client disconnected abnormally, otherwise it
|
||||
// will be nil on a normal disconnect.
|
||||
type OnDisconnect func(Client, error)
|
||||
|
||||
// OnError is called when errors that will not be passed to
|
||||
// OnDisconnect are handled by the server.
|
||||
type OnError func(Client, error)
|
||||
|
@@ -8,33 +8,39 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultBufferSize int = 1024 * 256 // the default size of the buffer in bytes.
|
||||
DefaultBlockSize int = 1024 * 8 // the default size per R/W block in bytes.
|
||||
// DefaultBufferSize is the default size of the buffer in bytes.
|
||||
DefaultBufferSize int = 1024 * 256
|
||||
|
||||
ErrOutOfRange = errors.New("Indexes out of range")
|
||||
// DefaultBlockSize is the default size per R/W block in bytes.
|
||||
DefaultBlockSize int = 1024 * 8
|
||||
|
||||
// ErrOutOfRange indicates that the index was out of range.
|
||||
ErrOutOfRange = errors.New("Indexes out of range")
|
||||
|
||||
// ErrInsufficientBytes indicates that there were not enough bytes to return.
|
||||
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
|
||||
)
|
||||
|
||||
// buffer contains core values and methods to be included in a reader or writer.
|
||||
// Buffer is a circular buffer for reading and writing messages.
|
||||
type Buffer struct {
|
||||
Mu sync.RWMutex // the buffer needs it's own mutex to work properly.
|
||||
ID string // the identifier of the buffer. This is used in debug output.
|
||||
size int // the size of the buffer.
|
||||
mask int // a bitmask of the buffer size (size-1).
|
||||
block int // the size of the R/W block.
|
||||
buf []byte // the bytes buffer.
|
||||
tmp []byte // a temporary buffer.
|
||||
Mu sync.RWMutex // the buffer needs its own mutex to work properly.
|
||||
ID string // the identifier of the buffer. This is used in debug output.
|
||||
head int64 // the current position in the sequence - a forever increasing index.
|
||||
tail int64 // the committed position in the sequence - a forever increasing index.
|
||||
rcond *sync.Cond // the sync condition for the buffer reader.
|
||||
wcond *sync.Cond // the sync condition for the buffer writer.
|
||||
done int64 // indicates that the buffer is closed.
|
||||
State int64 // indicates whether the buffer is reading from (1) or writing to (2).
|
||||
size int // the size of the buffer.
|
||||
mask int // a bitmask of the buffer size (size-1).
|
||||
block int // the size of the R/W block.
|
||||
done uint32 // indicates that the buffer is closed.
|
||||
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
|
||||
}
|
||||
|
||||
// NewBuffer returns a new instance of buffer. You should call NewReader or
|
||||
// NewWriter instead of this function.
|
||||
func NewBuffer(size, block int) Buffer {
|
||||
func NewBuffer(size, block int) *Buffer {
|
||||
if size == 0 {
|
||||
size = DefaultBufferSize
|
||||
}
|
||||
@@ -47,7 +53,7 @@ func NewBuffer(size, block int) Buffer {
|
||||
size = 2 * block
|
||||
}
|
||||
|
||||
return Buffer{
|
||||
return &Buffer{
|
||||
size: size,
|
||||
mask: size - 1,
|
||||
block: block,
|
||||
@@ -59,14 +65,14 @@ func NewBuffer(size, block int) Buffer {
|
||||
|
||||
// NewBufferFromSlice returns a new instance of buffer using a
|
||||
// pre-existing byte slice.
|
||||
func NewBufferFromSlice(block int, buf []byte) Buffer {
|
||||
func NewBufferFromSlice(block int, buf []byte) *Buffer {
|
||||
l := len(buf)
|
||||
|
||||
if block == 0 {
|
||||
block = DefaultBlockSize
|
||||
}
|
||||
|
||||
b := Buffer{
|
||||
b := &Buffer{
|
||||
size: l,
|
||||
mask: l - 1,
|
||||
block: block,
|
||||
@@ -78,7 +84,7 @@ func NewBufferFromSlice(block int, buf []byte) Buffer {
|
||||
return b
|
||||
}
|
||||
|
||||
// Get will return the tail and head positions of the buffer.
|
||||
// GetPos will return the tail and head positions of the buffer.
|
||||
// This method is for use with testing.
|
||||
func (b *Buffer) GetPos() (int64, int64) {
|
||||
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
|
||||
@@ -127,7 +133,7 @@ func (b *Buffer) awaitEmpty(n int) error {
|
||||
// then wait until tail has moved.
|
||||
b.rcond.L.Lock()
|
||||
for !b.checkEmpty(n) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
b.rcond.L.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
@@ -146,7 +152,7 @@ func (b *Buffer) awaitFilled(n int) error {
|
||||
// the forever-incrementing tail and head integers.
|
||||
b.wcond.L.Lock()
|
||||
for !b.checkFilled(n) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
b.wcond.L.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
@@ -196,7 +202,7 @@ func (b *Buffer) CapDelta() int {
|
||||
|
||||
// Stop signals the buffer to stop processing.
|
||||
func (b *Buffer) Stop() {
|
||||
atomic.StoreInt64(&b.done, 1)
|
||||
atomic.StoreUint32(&b.done, 1)
|
||||
b.rcond.L.Lock()
|
||||
b.rcond.Broadcast()
|
||||
b.rcond.L.Unlock()
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -50,6 +51,18 @@ func TestNewBufferFromSlice0Size(t *testing.T) {
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestAtomicAlignment(t *testing.T) {
|
||||
var b Buffer
|
||||
|
||||
offset := unsafe.Offsetof(b.head)
|
||||
require.Equalf(t, uintptr(0), offset%8,
|
||||
"head requires 64-bit alignment for atomic: offset %d", offset)
|
||||
|
||||
offset = unsafe.Offsetof(b.tail)
|
||||
require.Equalf(t, uintptr(0), offset%8,
|
||||
"tail requires 64-bit alignment for atomic: offset %d", offset)
|
||||
}
|
||||
|
||||
func TestGetPos(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
tail, head := buf.GetPos()
|
||||
@@ -140,7 +153,7 @@ func TestAwaitFilledEnded(t *testing.T) {
|
||||
o <- buf.awaitFilled(4)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
@@ -190,7 +203,7 @@ func TestAwaitEmptyEnded(t *testing.T) {
|
||||
o <- buf.awaitEmpty(4)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
@@ -280,7 +293,7 @@ func TestCommitTailEnded(t *testing.T) {
|
||||
o <- buf.CommitTail(5)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
@@ -300,5 +313,5 @@ func TestCapDelta(t *testing.T) {
|
||||
func TestStop(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
buf.Stop()
|
||||
require.Equal(t, int64(1), buf.done)
|
||||
require.Equal(t, uint32(1), buf.done)
|
||||
}
|
||||
|
@@ -2,17 +2,23 @@ package circ
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// BytesPool is a pool of []byte.
|
||||
type BytesPool struct {
|
||||
pool sync.Pool
|
||||
pool *sync.Pool
|
||||
used int64
|
||||
}
|
||||
|
||||
// NewBytesPool returns a sync.pool of []byte.
|
||||
func NewBytesPool(n int) BytesPool {
|
||||
return BytesPool{
|
||||
pool: sync.Pool{
|
||||
func NewBytesPool(n int) *BytesPool {
|
||||
if n == 0 {
|
||||
n = DefaultBufferSize
|
||||
}
|
||||
|
||||
return &BytesPool{
|
||||
pool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, n)
|
||||
},
|
||||
@@ -22,6 +28,7 @@ func NewBytesPool(n int) BytesPool {
|
||||
|
||||
// Get returns a pooled bytes.Buffer.
|
||||
func (b *BytesPool) Get() []byte {
|
||||
atomic.AddInt64(&b.used, 1)
|
||||
return b.pool.Get().([]byte)
|
||||
}
|
||||
|
||||
@@ -31,4 +38,10 @@ func (b *BytesPool) Put(x []byte) {
|
||||
x[i] = 0
|
||||
}
|
||||
b.pool.Put(x)
|
||||
atomic.AddInt64(&b.used, -1)
|
||||
}
|
||||
|
||||
// InUse returns the number of pool blocks in use.
|
||||
func (b *BytesPool) InUse() int64 {
|
||||
return atomic.LoadInt64(&b.used)
|
||||
}
|
||||
|
@@ -22,6 +22,7 @@ func TestNewBytesPoolGet(t *testing.T) {
|
||||
buf := bpool.Get()
|
||||
|
||||
require.Equal(t, make([]byte, 256), buf)
|
||||
require.Equal(t, int64(1), bpool.InUse())
|
||||
}
|
||||
|
||||
func BenchmarkBytesPoolGet(b *testing.B) {
|
||||
@@ -34,7 +35,9 @@ func BenchmarkBytesPoolGet(b *testing.B) {
|
||||
func TestNewBytesPoolPut(t *testing.T) {
|
||||
bpool := NewBytesPool(256)
|
||||
buf := bpool.Get()
|
||||
require.Equal(t, int64(1), bpool.InUse())
|
||||
bpool.Put(buf)
|
||||
require.Equal(t, int64(0), bpool.InUse())
|
||||
}
|
||||
|
||||
func BenchmarkBytesPoolPut(b *testing.B) {
|
||||
|
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
// Reader is a circular buffer for reading data from an io.Reader.
|
||||
type Reader struct {
|
||||
Buffer
|
||||
*Buffer
|
||||
}
|
||||
|
||||
// NewReader returns a new Circular Reader.
|
||||
@@ -19,7 +19,7 @@ func NewReader(size, block int) *Reader {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReaderFromSlice returns a new Circular Reader using a pre-exising
|
||||
// NewReaderFromSlice returns a new Circular Reader using a pre-existing
|
||||
// byte slice.
|
||||
func NewReaderFromSlice(block int, p []byte) *Reader {
|
||||
b := NewBufferFromSlice(block, p)
|
||||
@@ -32,10 +32,10 @@ func NewReaderFromSlice(block int, p []byte) *Reader {
|
||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
||||
// there is sufficient capacity to do so.
|
||||
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
atomic.StoreInt64(&b.State, 1)
|
||||
defer atomic.StoreInt64(&b.State, 0)
|
||||
atomic.StoreUint32(&b.State, 1)
|
||||
defer atomic.StoreUint32(&b.State, 0)
|
||||
for {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
return total, nil
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
n, err := r.Read(b.buf[start:end])
|
||||
total += int64(n) // incr total bytes read.
|
||||
if err != nil {
|
||||
return total, nil
|
||||
return total, err
|
||||
}
|
||||
|
||||
// Move the head forward however many bytes were read.
|
||||
|
@@ -2,6 +2,8 @@ package circ
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -34,16 +36,18 @@ func TestReadFrom(t *testing.T) {
|
||||
br := bytes.NewReader(b4)
|
||||
|
||||
_, err := buf.ReadFrom(br)
|
||||
require.NoError(t, err)
|
||||
require.True(t, errors.Is(err, io.EOF))
|
||||
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buf[:4])
|
||||
require.Equal(t, int64(4), buf.head)
|
||||
|
||||
br.Reset(b4)
|
||||
_, err = buf.ReadFrom(br)
|
||||
require.True(t, errors.Is(err, io.EOF))
|
||||
require.Equal(t, int64(8), buf.head)
|
||||
|
||||
br.Reset(b4)
|
||||
_, err = buf.ReadFrom(br)
|
||||
require.True(t, errors.Is(err, io.EOF))
|
||||
require.Equal(t, int64(12), buf.head)
|
||||
}
|
||||
|
||||
@@ -60,7 +64,7 @@ func TestReadFromWrap(t *testing.T) {
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
go func() {
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
@@ -116,7 +120,7 @@ func TestReadEnded(t *testing.T) {
|
||||
o <- err
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
@@ -1,14 +1,14 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Writer is a circular buffer for writing data to an io.Writer.
|
||||
type Writer struct {
|
||||
Buffer
|
||||
*Buffer
|
||||
}
|
||||
|
||||
// NewWriter returns a pointer to a new Circular Writer.
|
||||
@@ -20,7 +20,7 @@ func NewWriter(size, block int) *Writer {
|
||||
}
|
||||
}
|
||||
|
||||
// NewWriterFromSlice returns a new Circular Writer using a pre-exising
|
||||
// NewWriterFromSlice returns a new Circular Writer using a pre-existing
|
||||
// byte slice.
|
||||
func NewWriterFromSlice(block int, p []byte) *Writer {
|
||||
b := NewBufferFromSlice(block, p)
|
||||
@@ -31,11 +31,11 @@ func NewWriterFromSlice(block int, p []byte) *Writer {
|
||||
}
|
||||
|
||||
// WriteTo writes the contents of the buffer to an io.Writer.
|
||||
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
|
||||
atomic.StoreInt64(&b.State, 2)
|
||||
defer atomic.StoreInt64(&b.State, 0)
|
||||
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
|
||||
atomic.StoreUint32(&b.State, 2)
|
||||
defer atomic.StoreUint32(&b.State, 0)
|
||||
for {
|
||||
if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 {
|
||||
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
|
||||
return total, io.EOF
|
||||
}
|
||||
|
||||
@@ -59,14 +59,12 @@ func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
|
||||
p = append(p, b.buf[rTail:rHead]...)
|
||||
}
|
||||
|
||||
//fmt.Println("writing", p)
|
||||
n, err = w.Write(p)
|
||||
total += n
|
||||
total += int64(n)
|
||||
if err != nil {
|
||||
fmt.Println("writing err", err)
|
||||
log.Println("error writing to buffer io.Writer;", err)
|
||||
return
|
||||
}
|
||||
//fmt.Println("written", n)
|
||||
|
||||
// Move the tail forward the bytes written and broadcast change.
|
||||
atomic.StoreInt64(&b.tail, tail+int64(n))
|
||||
|
@@ -52,14 +52,14 @@ func TestWriteTo(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
w := bufio.NewWriter(&b)
|
||||
|
||||
nc := make(chan int)
|
||||
nc := make(chan int64)
|
||||
go func() {
|
||||
n, _ := buf.WriteTo(w)
|
||||
nc <- n
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/events"
|
||||
"github.com/mochi-co/mqtt/server/internal/circ"
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
"github.com/mochi-co/mqtt/server/internal/topics"
|
||||
@@ -23,7 +24,9 @@ var (
|
||||
// defaultKeepalive is the default connection keepalive value in seconds.
|
||||
defaultKeepalive uint16 = 10
|
||||
|
||||
ErrConnectionClosed = errors.New("Connection not open")
|
||||
// ErrConnectionClosed is returned when operating on a closed
|
||||
// connection and/or when no error cause has been given.
|
||||
ErrConnectionClosed = errors.New("connection not open")
|
||||
)
|
||||
|
||||
// Clients contains a map of the clients known by the broker.
|
||||
@@ -74,7 +77,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
cl.RLock()
|
||||
for _, v := range cl.internal {
|
||||
if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 {
|
||||
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
|
||||
clients = append(clients, v)
|
||||
}
|
||||
}
|
||||
@@ -84,42 +87,43 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
||||
|
||||
// Client contains information about a client known by the broker.
|
||||
type Client struct {
|
||||
sync.RWMutex
|
||||
conn net.Conn // the net.Conn used to establish the connection.
|
||||
r *circ.Reader // a reader for reading incoming bytes.
|
||||
w *circ.Writer // a writer for writing outgoing bytes.
|
||||
ID string // the client id.
|
||||
AC auth.Controller // an auth controller inherited from the listener.
|
||||
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
|
||||
Listener string // the id of the listener the client is connected to.
|
||||
Inflight Inflight // a map of in-flight qos messages.
|
||||
Username []byte // the username the client authenticated with.
|
||||
keepalive uint16 // the number of seconds the connection can wait.
|
||||
cleanSession bool // indicates if the client expects a clean-session.
|
||||
packetID uint32 // the current highest packetID.
|
||||
LWT LWT // the last will and testament for the client.
|
||||
State State // the operational state of the client.
|
||||
system *system.Info // pointers to server system info.
|
||||
LWT LWT // the last will and testament for the client.
|
||||
Inflight *Inflight // a map of in-flight qos messages.
|
||||
sync.RWMutex // mutex
|
||||
Username []byte // the username the client authenticated with.
|
||||
AC auth.Controller // an auth controller inherited from the listener.
|
||||
Listener string // the id of the listener the client is connected to.
|
||||
ID string // the client id.
|
||||
conn net.Conn // the net.Conn used to establish the connection.
|
||||
R *circ.Reader // a reader for reading incoming bytes.
|
||||
W *circ.Writer // a writer for writing outgoing bytes.
|
||||
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
|
||||
systemInfo *system.Info // pointers to server system info.
|
||||
packetID uint32 // the current highest packetID.
|
||||
keepalive uint16 // the number of seconds the connection can wait.
|
||||
CleanSession bool // indicates if the client expects a clean-session.
|
||||
}
|
||||
|
||||
// State tracks the state of the client.
|
||||
type State struct {
|
||||
Done int64 // atomic counter which indicates that the client has closed.
|
||||
started *sync.WaitGroup // tracks the goroutines which have been started.
|
||||
endedW *sync.WaitGroup // tracks when the writer has ended.
|
||||
endedR *sync.WaitGroup // tracks when the reader has ended.
|
||||
endOnce sync.Once // only end once.
|
||||
started *sync.WaitGroup // tracks the goroutines which have been started.
|
||||
endedW *sync.WaitGroup // tracks when the writer has ended.
|
||||
endedR *sync.WaitGroup // tracks when the reader has ended.
|
||||
Done uint32 // atomic counter which indicates that the client has closed.
|
||||
endOnce sync.Once // only end once.
|
||||
stopCause atomic.Value // reason for stopping.
|
||||
}
|
||||
|
||||
// NewClient returns a new instance of Client.
|
||||
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
|
||||
cl := &Client{
|
||||
conn: c,
|
||||
r: r,
|
||||
w: w,
|
||||
system: s,
|
||||
keepalive: defaultKeepalive,
|
||||
Inflight: Inflight{
|
||||
conn: c,
|
||||
R: r,
|
||||
W: w,
|
||||
systemInfo: s,
|
||||
keepalive: defaultKeepalive,
|
||||
Inflight: &Inflight{
|
||||
internal: make(map[uint16]InflightMessage),
|
||||
},
|
||||
Subscriptions: make(map[string]byte),
|
||||
@@ -139,7 +143,7 @@ func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Clie
|
||||
// method is typically called by the persistence restoration system.
|
||||
func NewClientStub(s *system.Info) *Client {
|
||||
return &Client{
|
||||
Inflight: Inflight{
|
||||
Inflight: &Inflight{
|
||||
internal: make(map[uint16]InflightMessage),
|
||||
},
|
||||
Subscriptions: make(map[string]byte),
|
||||
@@ -159,11 +163,11 @@ func (cl *Client) Identify(lid string, pk packets.Packet, ac auth.Controller) {
|
||||
cl.ID = xid.New().String()
|
||||
}
|
||||
|
||||
cl.r.ID = cl.ID + " READER"
|
||||
cl.w.ID = cl.ID + " WRITER"
|
||||
cl.R.ID = cl.ID + " READER"
|
||||
cl.W.ID = cl.ID + " WRITER"
|
||||
|
||||
cl.Username = pk.Username
|
||||
cl.cleanSession = pk.CleanSession
|
||||
cl.CleanSession = pk.CleanSession
|
||||
cl.keepalive = pk.Keepalive
|
||||
|
||||
if pk.WillFlag {
|
||||
@@ -185,7 +189,20 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second)
|
||||
}
|
||||
cl.conn.SetDeadline(expiry)
|
||||
_ = cl.conn.SetDeadline(expiry)
|
||||
}
|
||||
}
|
||||
|
||||
// Info returns an event-version of a client, containing minimal information.
|
||||
func (cl *Client) Info() events.Client {
|
||||
addr := "unknown"
|
||||
if cl.conn != nil && cl.conn.RemoteAddr() != nil {
|
||||
addr = cl.conn.RemoteAddr().String()
|
||||
}
|
||||
return events.Client{
|
||||
ID: cl.ID,
|
||||
Remote: addr,
|
||||
Listener: cl.Listener,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,47 +235,75 @@ func (cl *Client) ForgetSubscription(filter string) {
|
||||
// Start begins the client goroutines reading and writing packets.
|
||||
func (cl *Client) Start() {
|
||||
cl.State.started.Add(2)
|
||||
|
||||
go func() {
|
||||
cl.State.started.Done()
|
||||
cl.w.WriteTo(cl.conn)
|
||||
cl.State.endedW.Done()
|
||||
cl.Stop()
|
||||
}()
|
||||
cl.State.endedW.Add(1)
|
||||
cl.State.endedR.Add(1)
|
||||
|
||||
go func() {
|
||||
cl.State.started.Done()
|
||||
cl.r.ReadFrom(cl.conn)
|
||||
cl.State.endedR.Done()
|
||||
cl.Stop()
|
||||
_, err := cl.W.WriteTo(cl.conn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("writer: %w", err)
|
||||
}
|
||||
cl.State.endedW.Done()
|
||||
cl.Stop(err)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
cl.State.started.Done()
|
||||
_, err := cl.R.ReadFrom(cl.conn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("reader: %w", err)
|
||||
}
|
||||
cl.State.endedR.Done()
|
||||
cl.Stop(err)
|
||||
}()
|
||||
cl.State.endedR.Add(1)
|
||||
|
||||
cl.State.started.Wait()
|
||||
}
|
||||
|
||||
// ClearBuffers sets the read/write buffers to nil so they can be
|
||||
// deallocated automatically when no longer in use.
|
||||
func (cl *Client) ClearBuffers() {
|
||||
cl.R = nil
|
||||
cl.W = nil
|
||||
}
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop() {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 {
|
||||
// A cause error may be passed to identfy the reason for stopping.
|
||||
func (cl *Client) Stop(err error) {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
cl.State.endOnce.Do(func() {
|
||||
cl.r.Stop()
|
||||
cl.w.Stop()
|
||||
cl.R.Stop()
|
||||
cl.W.Stop()
|
||||
|
||||
cl.State.endedW.Wait()
|
||||
|
||||
cl.conn.Close()
|
||||
_ = cl.conn.Close() // omit close error
|
||||
|
||||
cl.State.endedR.Wait()
|
||||
atomic.StoreInt64(&cl.State.Done, 1)
|
||||
atomic.StoreUint32(&cl.State.Done, 1)
|
||||
|
||||
if err == nil {
|
||||
err = ErrConnectionClosed
|
||||
}
|
||||
cl.State.stopCause.Store(err)
|
||||
})
|
||||
}
|
||||
|
||||
// readFixedHeader reads in the values of the next packet's fixed header.
|
||||
// StopCause returns the reason the client connection was stopped, if any.
|
||||
func (cl *Client) StopCause() error {
|
||||
if cl.State.stopCause.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
return cl.State.stopCause.Load().(error)
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
p, err := cl.r.Read(1)
|
||||
p, err := cl.R.Read(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -275,7 +320,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
i := 1
|
||||
n := 2
|
||||
for ; n < 6; n++ {
|
||||
p, err = cl.r.Read(n)
|
||||
p, err = cl.R.Read(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -299,8 +344,8 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
fh.Remaining = int(rem)
|
||||
|
||||
// Having successfully read n bytes, commit the tail forward.
|
||||
cl.r.CommitTail(n)
|
||||
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
|
||||
cl.R.CommitTail(n)
|
||||
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -309,7 +354,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
// an error is encountered (or the connection is closed).
|
||||
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
|
||||
for {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.R.CapDelta() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -334,18 +379,18 @@ func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error
|
||||
|
||||
// ReadPacket reads the remaining buffer into an MQTT packet.
|
||||
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
|
||||
atomic.AddInt64(&cl.system.MessagesRecv, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
|
||||
|
||||
pk.FixedHeader = *fh
|
||||
if pk.FixedHeader.Remaining == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
p, err := cl.r.Read(pk.FixedHeader.Remaining)
|
||||
p, err := cl.R.Read(pk.FixedHeader.Remaining)
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
|
||||
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
|
||||
|
||||
// Decode the remaining packet values using a fresh copy of the bytes,
|
||||
// otherwise the next packet will change the data of this one.
|
||||
@@ -359,7 +404,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
case packets.Publish:
|
||||
err = pk.PublishDecode(px)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.system.PublishRecv, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackDecode(px)
|
||||
@@ -384,19 +429,19 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
err = fmt.Errorf("No valid packet available; %v", pk.FixedHeader.Type)
|
||||
}
|
||||
|
||||
cl.r.CommitTail(pk.FixedHeader.Remaining)
|
||||
cl.R.CommitTail(pk.FixedHeader.Remaining)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 {
|
||||
return 0, ErrConnectionClosed
|
||||
}
|
||||
|
||||
cl.w.Mu.Lock()
|
||||
defer cl.w.Mu.Unlock()
|
||||
cl.W.Mu.Lock()
|
||||
defer cl.W.Mu.Unlock()
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
switch pk.FixedHeader.Type {
|
||||
@@ -407,7 +452,7 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
case packets.Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.system.PublishSent, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
@@ -439,12 +484,13 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
}
|
||||
|
||||
// Write the packet bytes to the client byte buffer.
|
||||
n, err = cl.w.Write(buf.Bytes())
|
||||
n, err = cl.W.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&cl.system.BytesSent, int64(n))
|
||||
atomic.AddInt64(&cl.system.MessagesSent, 1)
|
||||
|
||||
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
|
||||
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
|
||||
|
||||
cl.refreshDeadline(cl.keepalive)
|
||||
|
||||
@@ -453,8 +499,8 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
|
||||
// LWT contains the last will and testament details for a client connection.
|
||||
type LWT struct {
|
||||
Topic string // the topic the will message shall be sent to.
|
||||
Message []byte // the message that shall be sent when the client disconnects.
|
||||
Topic string // the topic the will message shall be sent to.
|
||||
Qos byte // the quality of service desired.
|
||||
Retain bool // indicates whether the will message should be retained
|
||||
}
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/server/events"
|
||||
"github.com/mochi-co/mqtt/server/internal/circ"
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
@@ -16,6 +17,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testClientStop = errors.New("test stop")
|
||||
|
||||
func genClient() *Client {
|
||||
c, _ := net.Pipe()
|
||||
return NewClient(c, circ.NewReader(128, 8), circ.NewWriter(128, 8), new(system.Info))
|
||||
@@ -130,11 +133,39 @@ func TestNewClient(t *testing.T) {
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.Inflight.internal)
|
||||
require.NotNil(t, cl.Subscriptions)
|
||||
require.NotNil(t, cl.r)
|
||||
require.NotNil(t, cl.w)
|
||||
require.NotNil(t, cl.State.started)
|
||||
require.NotNil(t, cl.State.endedW)
|
||||
require.NotNil(t, cl.State.endedR)
|
||||
require.NotNil(t, cl.R)
|
||||
require.NotNil(t, cl.W)
|
||||
require.Nil(t, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestClientInfoUnknown(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.ID = "testid"
|
||||
cl.Listener = "testlistener"
|
||||
cl.conn = nil
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "testid",
|
||||
Remote: "unknown",
|
||||
Listener: "testlistener",
|
||||
}, cl.Info())
|
||||
}
|
||||
|
||||
func TestClientInfoKnown(t *testing.T) {
|
||||
c1, c2 := net.Pipe()
|
||||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
|
||||
cl := genClient()
|
||||
cl.ID = "ID"
|
||||
cl.Listener = "L"
|
||||
cl.conn = c1
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "ID",
|
||||
Remote: c1.RemoteAddr().String(),
|
||||
Listener: "L",
|
||||
}, cl.Info())
|
||||
}
|
||||
|
||||
func BenchmarkNewClient(b *testing.B) {
|
||||
@@ -175,7 +206,7 @@ func TestClientIdentify(t *testing.T) {
|
||||
|
||||
cl.Identify("tcp1", pk, new(auth.Allow))
|
||||
require.Equal(t, pk.Keepalive, cl.keepalive)
|
||||
require.Equal(t, pk.CleanSession, cl.cleanSession)
|
||||
require.Equal(t, pk.CleanSession, cl.CleanSession)
|
||||
require.Equal(t, pk.ClientIdentifier, cl.ID)
|
||||
}
|
||||
|
||||
@@ -314,15 +345,15 @@ func BenchmarkClientRefreshDeadline(b *testing.B) {
|
||||
func TestClientStart(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.r.State))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.w.State))
|
||||
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.R.State))
|
||||
require.Equal(t, uint32(2), atomic.LoadUint32(&cl.W.State))
|
||||
}
|
||||
|
||||
func BenchmarkClientStart(b *testing.B) {
|
||||
cl := genClient()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
cl.Start()
|
||||
@@ -332,17 +363,17 @@ func BenchmarkClientStart(b *testing.B) {
|
||||
func TestClientReadFixedHeader(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
cl.r.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
|
||||
cl.r.SetPos(0, 2)
|
||||
cl.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
|
||||
cl.R.SetPos(0, 2)
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
|
||||
|
||||
tail, head := cl.r.GetPos()
|
||||
tail, head := cl.R.GetPos()
|
||||
require.Equal(t, int64(2), tail)
|
||||
require.Equal(t, int64(2), head)
|
||||
|
||||
@@ -351,13 +382,13 @@ func TestClientReadFixedHeader(t *testing.T) {
|
||||
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
fh := new(packets.FixedHeader)
|
||||
cl.r.Set([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}, 0, 2)
|
||||
cl.r.SetPos(0, 2)
|
||||
cl.R.Set([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00}, 0, 2)
|
||||
cl.R.SetPos(0, 2)
|
||||
o <- cl.ReadFixedHeader(fh)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
@@ -367,17 +398,17 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
fh := new(packets.FixedHeader)
|
||||
cl.r.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
|
||||
cl.r.SetPos(0, 1)
|
||||
cl.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
|
||||
cl.R.SetPos(0, 1)
|
||||
o <- cl.ReadFixedHeader(fh)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
cl.r.Stop()
|
||||
cl.R.Stop()
|
||||
err := <-o
|
||||
require.Error(t, err)
|
||||
require.Equal(t, io.EOF, err)
|
||||
@@ -386,14 +417,14 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.r.Set([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, 0, 5)
|
||||
err := cl.R.Set([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, 0, 5)
|
||||
require.NoError(t, err)
|
||||
cl.r.SetPos(0, 5)
|
||||
cl.R.SetPos(0, 5)
|
||||
o <- cl.ReadFixedHeader(fh)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
@@ -403,7 +434,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
func TestClientReadOK(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
// Two packets in a row...
|
||||
b := []byte{
|
||||
@@ -417,9 +448,9 @@ func TestClientReadOK(t *testing.T) {
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
}
|
||||
|
||||
err := cl.r.Set(b, 0, len(b))
|
||||
err := cl.R.Set(b, 0, len(b))
|
||||
require.NoError(t, err)
|
||||
cl.r.SetPos(0, int64(len(b)))
|
||||
cl.R.SetPos(0, int64(len(b)))
|
||||
|
||||
o := make(chan error)
|
||||
var pks []packets.Packet
|
||||
@@ -431,7 +462,7 @@ func TestClientReadOK(t *testing.T) {
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
cl.r.Stop()
|
||||
cl.R.Stop()
|
||||
|
||||
err = <-o
|
||||
require.Error(t, err)
|
||||
@@ -456,15 +487,25 @@ func TestClientReadOK(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.system.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.MessagesRecv))
|
||||
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.MessagesRecv))
|
||||
|
||||
}
|
||||
|
||||
func TestClientClearBuffers(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
cl.Stop(testClientStop)
|
||||
cl.ClearBuffers()
|
||||
|
||||
require.Nil(t, cl.W)
|
||||
require.Nil(t, cl.R)
|
||||
}
|
||||
|
||||
func TestClientReadDone(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
cl.State.Done = 1
|
||||
|
||||
err := cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
@@ -477,7 +518,7 @@ func TestClientReadDone(t *testing.T) {
|
||||
func TestClientReadPacketError(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
b := []byte{
|
||||
0, 18,
|
||||
@@ -485,9 +526,9 @@ func TestClientReadPacketError(t *testing.T) {
|
||||
'a', '/', 'b', '/', 'c',
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i',
|
||||
}
|
||||
err := cl.r.Set(b, 0, len(b))
|
||||
err := cl.R.Set(b, 0, len(b))
|
||||
require.NoError(t, err)
|
||||
cl.r.SetPos(0, int64(len(b)))
|
||||
cl.R.SetPos(0, int64(len(b)))
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
@@ -499,10 +540,37 @@ func TestClientReadPacketError(t *testing.T) {
|
||||
require.Error(t, <-o)
|
||||
}
|
||||
|
||||
func TestClientReadPacketEOF(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
|
||||
b := []byte{
|
||||
0, 18,
|
||||
0, 5,
|
||||
'a', '/', 'b', '/', 'c',
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', // missing 1 byte
|
||||
}
|
||||
err := cl.R.Set(b, 0, len(b))
|
||||
require.NoError(t, err)
|
||||
cl.R.SetPos(0, int64(len(b)))
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
cl.R.Stop()
|
||||
cl.Stop(testClientStop)
|
||||
require.Error(t, <-o)
|
||||
require.True(t, errors.Is(cl.StopCause(), testClientStop))
|
||||
}
|
||||
|
||||
func TestClientReadHandlerErr(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
b := []byte{
|
||||
byte(packets.Publish << 4), 11, // Fixed header
|
||||
@@ -511,9 +579,9 @@ func TestClientReadHandlerErr(t *testing.T) {
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
}
|
||||
|
||||
err := cl.r.Set(b, 0, len(b))
|
||||
err := cl.R.Set(b, 0, len(b))
|
||||
require.NoError(t, err)
|
||||
cl.r.SetPos(0, int64(len(b)))
|
||||
cl.R.SetPos(0, int64(len(b)))
|
||||
|
||||
err = cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
return errors.New("test")
|
||||
@@ -525,16 +593,16 @@ func TestClientReadHandlerErr(t *testing.T) {
|
||||
func TestClientReadPacketOK(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
err := cl.r.Set([]byte{
|
||||
err := cl.R.Set([]byte{
|
||||
byte(packets.Publish << 4), 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
}, 0, 13)
|
||||
require.NoError(t, err)
|
||||
cl.r.SetPos(0, 13)
|
||||
cl.R.SetPos(0, 13)
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err = cl.ReadFixedHeader(fh)
|
||||
@@ -557,12 +625,12 @@ func TestClientReadPacketOK(t *testing.T) {
|
||||
func TestClientReadPacket(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
for i, tt := range pkTable {
|
||||
err := cl.r.Set(tt.bytes, 0, len(tt.bytes))
|
||||
err := cl.R.Set(tt.bytes, 0, len(tt.bytes))
|
||||
require.NoError(t, err)
|
||||
cl.r.SetPos(0, int64(len(tt.bytes)))
|
||||
cl.R.SetPos(0, int64(len(tt.bytes)))
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err = cl.ReadFixedHeader(fh)
|
||||
@@ -574,7 +642,7 @@ func TestClientReadPacket(t *testing.T) {
|
||||
|
||||
require.Equal(t, tt.packet, pk, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
|
||||
if tt.packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishRecv))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishRecv))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -582,16 +650,16 @@ func TestClientReadPacket(t *testing.T) {
|
||||
func TestClientReadPacketReadingError(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
|
||||
err := cl.r.Set([]byte{
|
||||
err := cl.R.Set([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
}, 0, 13)
|
||||
require.NoError(t, err)
|
||||
cl.r.SetPos(2, 13)
|
||||
cl.R.SetPos(2, 13)
|
||||
|
||||
_, err = cl.ReadPacket(&packets.FixedHeader{
|
||||
Type: 0,
|
||||
@@ -603,8 +671,8 @@ func TestClientReadPacketReadingError(t *testing.T) {
|
||||
func TestClientReadPacketReadError(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
cl.r.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
cl.R.Stop()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
Remaining: 1,
|
||||
@@ -616,8 +684,8 @@ func TestClientReadPacketReadError(t *testing.T) {
|
||||
func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
cl := genClient()
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
cl.r.Stop()
|
||||
defer cl.Stop(testClientStop)
|
||||
cl.R.Stop()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
Remaining: 1,
|
||||
@@ -646,11 +714,21 @@ func TestClientWritePacket(t *testing.T) {
|
||||
r.Close()
|
||||
|
||||
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
|
||||
cl.Stop()
|
||||
require.Equal(t, int64(n), atomic.LoadInt64(&cl.system.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.MessagesSent))
|
||||
cl.Stop(testClientStop)
|
||||
|
||||
// The stop cause is either the test error, EOF, or a
|
||||
// closed pipe, depending on which goroutine runs first.
|
||||
err = cl.StopCause()
|
||||
time.Sleep(time.Millisecond * 5)
|
||||
require.True(t,
|
||||
errors.Is(err, testClientStop) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, io.ErrClosedPipe))
|
||||
|
||||
require.Equal(t, int64(n), atomic.LoadInt64(&cl.systemInfo.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.MessagesSent))
|
||||
if tt.packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishSent))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -658,8 +736,8 @@ func TestClientWritePacket(t *testing.T) {
|
||||
func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
c, _ := net.Pipe()
|
||||
cl := NewClient(c, circ.NewReader(16, 4), circ.NewWriter(16, 4), new(system.Info))
|
||||
cl.w.SetPos(0, 16)
|
||||
cl.Stop()
|
||||
cl.W.SetPos(0, 16)
|
||||
cl.Stop(testClientStop)
|
||||
|
||||
_, err := cl.WritePacket(pkTable[1].packet)
|
||||
require.Error(t, err)
|
||||
@@ -669,8 +747,8 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
c, _ := net.Pipe()
|
||||
cl := NewClient(c, circ.NewReader(16, 4), circ.NewWriter(16, 4), new(system.Info))
|
||||
cl.w.SetPos(0, 16)
|
||||
cl.w.Stop()
|
||||
cl.W.SetPos(0, 16)
|
||||
cl.W.Stop()
|
||||
|
||||
_, err := cl.WritePacket(pkTable[1].packet)
|
||||
require.Error(t, err)
|
||||
@@ -729,7 +807,7 @@ func TestInflightGetAll(t *testing.T) {
|
||||
|
||||
m := cl.Inflight.GetAll()
|
||||
o := map[uint16]InflightMessage{
|
||||
2: InflightMessage{},
|
||||
2: {},
|
||||
}
|
||||
require.Equal(t, o, m)
|
||||
}
|
||||
|
@@ -28,6 +28,10 @@ func decodeString(buf []byte, offset int) (string, int, error) {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if !validUTF8(b) {
|
||||
return "", 0, ErrOffsetStrInvalidUTF8
|
||||
}
|
||||
|
||||
return bytesToString(b), n, nil
|
||||
}
|
||||
|
||||
@@ -39,12 +43,10 @@ func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
|
||||
}
|
||||
|
||||
if next+int(length) > len(buf) {
|
||||
return make([]byte, 0, 0), 0, ErrOffsetStrOutOfRange
|
||||
return make([]byte, 0, 0), 0, ErrOffsetBytesOutOfRange
|
||||
}
|
||||
|
||||
if !validUTF8(buf[next : next+int(length)]) {
|
||||
return make([]byte, 0, 0), 0, ErrOffsetStrInvalidUTF8
|
||||
}
|
||||
// Note: there is no validUTF8() test for []byte payloads
|
||||
|
||||
return buf[next : next+int(length)], next + int(length), nil
|
||||
}
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package packets
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -19,15 +21,16 @@ func BenchmarkBytesToString(b *testing.B) {
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
rawBytes []byte
|
||||
result []string
|
||||
result string
|
||||
offset int
|
||||
shouldFail bool
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: []string{"a/b/c/d", "a"},
|
||||
result: "a/b/c/d",
|
||||
},
|
||||
{
|
||||
offset: 14,
|
||||
@@ -41,36 +44,32 @@ func TestDecodeString(t *testing.T) {
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'h', 'e', 'y', // Client ID "zen"},
|
||||
},
|
||||
result: []string{"hey"},
|
||||
result: "hey",
|
||||
},
|
||||
|
||||
{
|
||||
offset: 2,
|
||||
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
|
||||
result: []string{"1/2/3/4/a/b/c/d/e/^/@/!", "a"},
|
||||
result: "1/2/3/4/a/b/c/d/e/^/@/!",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
|
||||
result: []string{"x/y/z", "!@#$%^&"},
|
||||
result: "x/y/z",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
|
||||
result: []string{"a/b/c/d", "z"},
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 5,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
|
||||
result: []string{"a/b/c/d", "x"},
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 9,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
|
||||
result: []string{"a/b/c/d", "y"},
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 17,
|
||||
@@ -86,20 +85,26 @@ func TestDecodeString(t *testing.T) {
|
||||
0, 6, // Will Topic - MSB+LSB
|
||||
'l',
|
||||
},
|
||||
result: []string{"lwt"},
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrOffsetStrInvalidUTF8,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail {
|
||||
require.Error(t, err, "Expected error decoding string [i:%d]", i)
|
||||
continue
|
||||
}
|
||||
|
||||
require.NoError(t, err, "Error decoding string [i:%d]", i)
|
||||
require.Equal(t, wanted.result[0], result, "Incorrect decoded value [i:%d]", i)
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,7 +121,7 @@ func TestDecodeBytes(t *testing.T) {
|
||||
result []uint8
|
||||
next int
|
||||
offset int
|
||||
shouldFail bool
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session)
|
||||
@@ -132,33 +137,27 @@ func TestDecodeBytes(t *testing.T) {
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
|
||||
offset: 0,
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
|
||||
offset: 8,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}),
|
||||
offset: 0,
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail {
|
||||
require.Error(t, err, "Expected error decoding bytes [i:%d]", i)
|
||||
continue
|
||||
}
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "Error decoding bytes [i:%d]", i)
|
||||
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,7 +173,7 @@ func TestDecodeByte(t *testing.T) {
|
||||
rawBytes []byte
|
||||
result uint8
|
||||
offset int
|
||||
shouldFail bool
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
|
||||
@@ -198,22 +197,23 @@ func TestDecodeByte(t *testing.T) {
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 80, 82, 84},
|
||||
result: uint8(0x00),
|
||||
offset: 8,
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetByteOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail {
|
||||
require.Error(t, err, "Expected error decoding byte [i:%d]", i)
|
||||
continue
|
||||
}
|
||||
|
||||
require.NoError(t, err, "Error decoding byte [i:%d]", i)
|
||||
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
|
||||
require.Equal(t, i+1, offset, "Incorrect offset value [i:%d]", i)
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ func TestDecodeUint16(t *testing.T) {
|
||||
rawBytes []byte
|
||||
result uint16
|
||||
offset int
|
||||
shouldFail bool
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
@@ -243,22 +243,24 @@ func TestDecodeUint16(t *testing.T) {
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
result: uint16(0x761),
|
||||
offset: 8,
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail {
|
||||
require.Error(t, err, "Expected error decoding uint16 [i:%d]", i)
|
||||
continue
|
||||
}
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err, "Error decoding uint16 [i:%d]", i)
|
||||
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
|
||||
require.Equal(t, i+2, offset, "Incorrect offset value [i:%d]", i)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+2, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,7 +276,7 @@ func TestDecodeByteBool(t *testing.T) {
|
||||
rawBytes []byte
|
||||
result bool
|
||||
offset int
|
||||
shouldFail bool
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0x00, 0x00},
|
||||
@@ -287,20 +289,22 @@ func TestDecodeByteBool(t *testing.T) {
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
offset: 5,
|
||||
shouldFail: true,
|
||||
shouldFail: ErrOffsetBoolOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail {
|
||||
require.Error(t, err, "Expected error decoding byte bool [i:%d]", i)
|
||||
continue
|
||||
}
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "Error decoding byte bool [i:%d]", i)
|
||||
require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i)
|
||||
require.Equal(t, 1, offset, "Incorrect offset value [i:%d]", i)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, 1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -6,20 +6,20 @@ import (
|
||||
|
||||
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
|
||||
type FixedHeader struct {
|
||||
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Dup bool // indicates if the packet was already sent at an earlier time.
|
||||
Qos byte // indicates the quality of service expected.
|
||||
Retain bool // whether the message should be retained.
|
||||
Remaining int // the number of remaining bytes in the payload.
|
||||
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Qos byte // indicates the quality of service expected.
|
||||
Dup bool // indicates if the packet was already sent at an earlier time.
|
||||
Retain bool // whether the message should be retained.
|
||||
}
|
||||
|
||||
// Encode encodes the FixedHeader and returns a bytes buffer.
|
||||
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
|
||||
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
|
||||
encodeLength(buf, fh.Remaining)
|
||||
encodeLength(buf, int64(fh.Remaining))
|
||||
}
|
||||
|
||||
// decode extracts the specification bits from the header byte.
|
||||
// Decode extracts the specification bits from the header byte.
|
||||
func (fh *FixedHeader) Decode(headerByte byte) error {
|
||||
fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes.
|
||||
|
||||
@@ -44,7 +44,7 @@ func (fh *FixedHeader) Decode(headerByte byte) error {
|
||||
}
|
||||
|
||||
// encodeLength writes length bits for the header.
|
||||
func encodeLength(buf *bytes.Buffer, length int) {
|
||||
func encodeLength(buf *bytes.Buffer, length int64) {
|
||||
for {
|
||||
digit := byte(length % 128)
|
||||
length /= 128
|
||||
|
@@ -18,130 +18,130 @@ type fixedHeaderTable struct {
|
||||
var fixedHeaderExpected = []fixedHeaderTable{
|
||||
{
|
||||
rawBytes: []byte{Connect << 4, 0x00},
|
||||
header: FixedHeader{Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connack << 4, 0x00},
|
||||
header: FixedHeader{Connack, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x00},
|
||||
header: FixedHeader{Publish, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Publish, false, 1, false, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Publish, false, 1, true, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
|
||||
header: FixedHeader{Publish, false, 2, false, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Publish, false, 2, true, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Publish, true, 0, false, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
|
||||
header: FixedHeader{Publish, true, 0, true, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Publish, true, 1, true, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Publish, true, 2, true, 0},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Puback << 4, 0x00},
|
||||
header: FixedHeader{Puback, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubrec << 4, 0x00},
|
||||
header: FixedHeader{Pubrec, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Pubrel, false, 1, false, 0},
|
||||
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pubcomp << 4, 0x00},
|
||||
header: FixedHeader{Pubcomp, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Subscribe, false, 1, false, 0},
|
||||
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Suback << 4, 0x00},
|
||||
header: FixedHeader{Suback, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Unsubscribe, false, 1, false, 0},
|
||||
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Unsuback << 4, 0x00},
|
||||
header: FixedHeader{Unsuback, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pingreq << 4, 0x00},
|
||||
header: FixedHeader{Pingreq, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Pingresp << 4, 0x00},
|
||||
header: FixedHeader{Pingresp, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Disconnect << 4, 0x00},
|
||||
header: FixedHeader{Disconnect, false, 0, false, 0},
|
||||
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
|
||||
// remaining length
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x0a},
|
||||
header: FixedHeader{Publish, false, 0, false, 10},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x80, 0x04},
|
||||
header: FixedHeader{Publish, false, 0, false, 512},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
|
||||
header: FixedHeader{Publish, false, 0, false, 978},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
|
||||
header: FixedHeader{Publish, false, 0, false, 20102},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
|
||||
header: FixedHeader{Publish, false, 0, false, 333333333},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
|
||||
packetError: true,
|
||||
},
|
||||
|
||||
// Invalid flags for packet
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Connect, true, 0, false, 0},
|
||||
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Connect, false, 1, false, 0},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{Connect<<4 | 1, 0x00},
|
||||
header: FixedHeader{Connect, false, 0, true, 0},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
|
||||
flagError: true,
|
||||
},
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func BenchmarkFixedHeaderDecode(b *testing.B) {
|
||||
|
||||
func TestEncodeLength(t *testing.T) {
|
||||
tt := []struct {
|
||||
have int
|
||||
have int64
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
|
@@ -3,6 +3,8 @@ package packets
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// All of the valid packet types and their packet identifier.
|
||||
@@ -60,7 +62,6 @@ var (
|
||||
|
||||
// PACKETS
|
||||
ErrProtocolViolation = errors.New("protocol violation")
|
||||
ErrOffsetStrOutOfRange = errors.New("offset string out of range")
|
||||
ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range")
|
||||
ErrOffsetByteOutOfRange = errors.New("offset byte out of range")
|
||||
ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")
|
||||
@@ -76,40 +77,31 @@ var (
|
||||
// packet structs, this is a single concrete packet type to cover all packet
|
||||
// types, which allows us to take advantage of various compiler optimizations.
|
||||
type Packet struct {
|
||||
FixedHeader FixedHeader
|
||||
|
||||
PacketID uint16
|
||||
|
||||
// Connect
|
||||
FixedHeader FixedHeader
|
||||
AllowClients []string // For use with OnMessage event hook.
|
||||
Topics []string
|
||||
ReturnCodes []byte
|
||||
ProtocolName []byte
|
||||
Qoss []byte
|
||||
Payload []byte
|
||||
Username []byte
|
||||
Password []byte
|
||||
WillMessage []byte
|
||||
ClientIdentifier string
|
||||
TopicName string
|
||||
WillTopic string
|
||||
PacketID uint16
|
||||
Keepalive uint16
|
||||
ReturnCode byte
|
||||
ProtocolVersion byte
|
||||
WillQos byte
|
||||
ReservedBit byte
|
||||
CleanSession bool
|
||||
WillFlag bool
|
||||
WillQos byte
|
||||
WillRetain bool
|
||||
UsernameFlag bool
|
||||
PasswordFlag bool
|
||||
ReservedBit byte
|
||||
Keepalive uint16
|
||||
ClientIdentifier string
|
||||
WillTopic string
|
||||
WillMessage []byte
|
||||
Username []byte
|
||||
Password []byte
|
||||
|
||||
// Connack
|
||||
SessionPresent bool
|
||||
ReturnCode byte
|
||||
|
||||
// Publish
|
||||
TopicName string
|
||||
Payload []byte
|
||||
|
||||
// Subscribe, Unsubscribe
|
||||
Topics []string
|
||||
Qoss []byte
|
||||
|
||||
ReturnCodes []byte // Suback
|
||||
SessionPresent bool
|
||||
}
|
||||
|
||||
// ConnectEncode encodes a connect packet.
|
||||
@@ -169,17 +161,17 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
// Unpack protocol name and version.
|
||||
pk.ProtocolName, offset, err = decodeBytes(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedProtocolName
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolName)
|
||||
}
|
||||
|
||||
pk.ProtocolVersion, offset, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedProtocolVersion
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProtocolVersion)
|
||||
}
|
||||
// Unpack flags byte.
|
||||
flags, offset, err := decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedFlags
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedFlags)
|
||||
}
|
||||
pk.ReservedBit = 1 & flags
|
||||
pk.CleanSession = 1&(flags>>1) > 0
|
||||
@@ -192,25 +184,25 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
// Get keepalive interval.
|
||||
pk.Keepalive, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedKeepalive
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedKeepalive)
|
||||
}
|
||||
|
||||
// Get client ID.
|
||||
pk.ClientIdentifier, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedClientID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedClientID)
|
||||
}
|
||||
|
||||
// Get Last Will and Testament topic and message if applicable.
|
||||
if pk.WillFlag {
|
||||
pk.WillTopic, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedWillTopic
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedWillTopic)
|
||||
}
|
||||
|
||||
pk.WillMessage, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedWillMessage
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedWillMessage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,14 +210,14 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
if pk.UsernameFlag {
|
||||
pk.Username, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedUsername
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedUsername)
|
||||
}
|
||||
}
|
||||
|
||||
if pk.PasswordFlag {
|
||||
pk.Password, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedPassword
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,12 +284,12 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
|
||||
|
||||
pk.SessionPresent, offset, err = decodeByteBool(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedSessionPresent
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
|
||||
}
|
||||
|
||||
pk.ReturnCode, offset, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedReturnCode
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -334,7 +326,7 @@ func (pk *Packet) PubackDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -352,7 +344,7 @@ func (pk *Packet) PubcompDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -390,14 +382,14 @@ func (pk *Packet) PublishDecode(buf []byte) error {
|
||||
|
||||
pk.TopicName, offset, err = decodeString(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedTopic
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
|
||||
// If QOS decode Packet ID.
|
||||
if pk.FixedHeader.Qos > 0 {
|
||||
pk.PacketID, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,7 +443,7 @@ func (pk *Packet) PubrecDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -470,7 +462,7 @@ func (pk *Packet) PubrelDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -495,7 +487,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
|
||||
// Get Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Get Granted QOS flags.
|
||||
@@ -542,7 +534,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
|
||||
// Get the Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Keep decoding until there's no space left.
|
||||
@@ -552,7 +544,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
|
||||
var topic string
|
||||
topic, offset, err = decodeString(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedTopic
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
pk.Topics = append(pk.Topics, topic)
|
||||
|
||||
@@ -560,7 +552,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
|
||||
var qos byte
|
||||
qos, offset, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedQoS
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedQoS)
|
||||
}
|
||||
|
||||
// Ensure QoS byte is within range.
|
||||
@@ -599,7 +591,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
|
||||
var err error
|
||||
pk.PacketID, _, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -641,7 +633,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
|
||||
// Get the Packet ID.
|
||||
pk.PacketID, offset, err = decodeUint16(buf, 0)
|
||||
if err != nil {
|
||||
return ErrMalformedPacketID
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPacketID)
|
||||
}
|
||||
|
||||
// Keep decoding until there's no space left.
|
||||
@@ -649,7 +641,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
|
||||
var t string
|
||||
t, offset, err = decodeString(buf, offset) // Decode Topic Name.
|
||||
if err != nil {
|
||||
return ErrMalformedTopic
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedTopic)
|
||||
}
|
||||
|
||||
if len(t) > 0 {
|
||||
@@ -671,3 +663,8 @@ func (pk *Packet) UnsubscribeValidate() (byte, error) {
|
||||
|
||||
return Accepted, nil
|
||||
}
|
||||
|
||||
// FormatID returns the PacketID field as a decimal integer.
|
||||
func (pk *Packet) FormatID() string {
|
||||
return strconv.FormatUint(uint64(pk.PacketID), 10)
|
||||
}
|
||||
|
@@ -6,7 +6,7 @@ type packetTestData struct {
|
||||
actualBytes []byte // the actual byte array that is created in the event of a byte mutation (eg. MQTT-2.3.1-1 qos/packet id)
|
||||
packet *Packet // the packet that is expected
|
||||
desc string // a description of the test
|
||||
failFirst interface{} // expected fail result to be run immediately after the method is called
|
||||
failFirst error // expected fail result to be run immediately after the method is called
|
||||
expect interface{} // generic expected fail result to be checked
|
||||
isolate bool // isolate can be used to isolate a test
|
||||
primary bool // primary is a test that should be run using readPackets
|
||||
|
@@ -2,6 +2,8 @@ package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/copier"
|
||||
@@ -74,13 +76,13 @@ func TestConnectDecode(t *testing.T) {
|
||||
}
|
||||
|
||||
require.Equal(t, uint8(1), Connect, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficent bytes in packet [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficient bytes in packet [i:%d] %s", i, wanted.desc)
|
||||
|
||||
pk := &Packet{FixedHeader: FixedHeader{Type: Connect}}
|
||||
err := pk.ConnectDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -180,7 +182,7 @@ func TestConnackDecode(t *testing.T) {
|
||||
err := pk.ConnackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -339,7 +341,7 @@ func TestPubackDecode(t *testing.T) {
|
||||
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -407,7 +409,7 @@ func TestPubcompDecode(t *testing.T) {
|
||||
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -494,7 +496,7 @@ func TestPublishDecode(t *testing.T) {
|
||||
err := pk.PublishDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected fh error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -630,7 +632,7 @@ func TestPubrecDecode(t *testing.T) {
|
||||
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -703,7 +705,7 @@ func TestPubrelDecode(t *testing.T) {
|
||||
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -775,7 +777,7 @@ func TestSubackDecode(t *testing.T) {
|
||||
err := pk.SubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -854,7 +856,7 @@ func TestSubscribeDecode(t *testing.T) {
|
||||
err := pk.SubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -957,7 +959,7 @@ func TestUnsubackDecode(t *testing.T) {
|
||||
err := pk.UnsubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1034,7 +1036,7 @@ func TestUnsubscribeDecode(t *testing.T) {
|
||||
err := pk.UnsubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
|
||||
if wanted.failFirst != nil {
|
||||
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
|
||||
require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
require.True(t, errors.Is(err, wanted.failFirst), "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1080,3 +1082,10 @@ func BenchmarkUnsubscribeValidate(b *testing.B) {
|
||||
pk.UnsubscribeValidate()
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatPacketID(t *testing.T) {
|
||||
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
|
||||
packet := &Packet{PacketID: id}
|
||||
require.Equal(t, fmt.Sprint(id), packet.FormatID())
|
||||
}
|
||||
}
|
||||
|
@@ -27,27 +27,28 @@ func New() *Index {
|
||||
}
|
||||
|
||||
// RetainMessage saves a message payload to the end of a topic branch. Returns
|
||||
// 1 if a retained message was added, 0 if there was no change, and -1 if the
|
||||
// retained message was removed.
|
||||
// 1 if a retained message was added, and -1 if the retained message was removed.
|
||||
// 0 is returned if sequential empty payloads are received.
|
||||
func (x *Index) RetainMessage(msg packets.Packet) int64 {
|
||||
var q int64
|
||||
|
||||
x.mu.Lock()
|
||||
defer x.mu.Unlock()
|
||||
n := x.poperate(msg.TopicName)
|
||||
|
||||
// If there is a payload, we can store it.
|
||||
if len(msg.Payload) > 0 {
|
||||
if n.Message.FixedHeader.Retain == false {
|
||||
q = 1
|
||||
}
|
||||
n.Message = msg
|
||||
} else {
|
||||
if n.Message.FixedHeader.Retain == true {
|
||||
q = -1
|
||||
}
|
||||
x.unpoperate(msg.TopicName, "", true)
|
||||
return 1
|
||||
}
|
||||
|
||||
return q
|
||||
// Otherwise, we are unsetting it.
|
||||
// If there was a previous retained message, return -1 instead of 0.
|
||||
var r int64 = 0
|
||||
if len(n.Message.Payload) > 0 && n.Message.FixedHeader.Retain == true {
|
||||
r = -1
|
||||
}
|
||||
x.unpoperate(msg.TopicName, "", true)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Subscribe creates a subscription filter for a client. Returns true if the
|
||||
@@ -175,12 +176,12 @@ func (x *Index) Messages(filter string) []packets.Packet {
|
||||
|
||||
// Leaf is a child node on the tree.
|
||||
type Leaf struct {
|
||||
Message packets.Packet // a message which has been retained for a specific topic.
|
||||
Key string // the key that was used to create the leaf.
|
||||
Filter string // the path of the topic filter being matched.
|
||||
Parent *Leaf // a pointer to the parent node for the leaf.
|
||||
Leaves map[string]*Leaf // a map of child nodes, keyed on particle id.
|
||||
Clients map[string]byte // a map of client ids subscribed to the topic.
|
||||
Filter string // the path of the topic filter being matched.
|
||||
Message packets.Packet // a message which has been retained for a specific topic.
|
||||
}
|
||||
|
||||
// scanSubscribers recursively steps through a branch of leaves finding clients who
|
||||
|
@@ -113,8 +113,10 @@ func TestRetainMessage(t *testing.T) {
|
||||
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
|
||||
|
||||
q = index.RetainMessage(pk2) // already exsiting
|
||||
require.Equal(t, int64(0), q)
|
||||
// The same message already exists, but we're not doing a deep-copy check, so it's considered
|
||||
// to be a new message.
|
||||
q = index.RetainMessage(pk2)
|
||||
require.Equal(t, int64(1), q)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"])
|
||||
require.Equal(t, pk2, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message)
|
||||
require.Contains(t, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Clients, "client-1")
|
||||
@@ -126,6 +128,14 @@ func TestRetainMessage(t *testing.T) {
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
|
||||
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
|
||||
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
|
||||
|
||||
// Second Delete retained
|
||||
q = index.RetainMessage(pk3)
|
||||
require.Equal(t, int64(0), q)
|
||||
require.NotNil(t, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"])
|
||||
require.Equal(t, pk, index.Root.Leaves["path"].Leaves["to"].Leaves["my"].Leaves["mqtt"].Message)
|
||||
require.Equal(t, false, index.Root.Leaves["path"].Leaves["to"].Leaves["another"].Leaves["mqtt"].Message.FixedHeader.Retain)
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkRetainMessage(b *testing.B) {
|
||||
|
14
server/internal/utils/utils.go
Normal file
14
server/internal/utils/utils.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package utils
|
||||
|
||||
// InSliceString returns true if a string exists in a slice of strings.
|
||||
// This temporary and should be replaced with a function from the new
|
||||
// go slices package in 1.19 when available.
|
||||
// https://github.com/golang/go/issues/45955
|
||||
func InSliceString(sl []string, st string) bool {
|
||||
for _, v := range sl {
|
||||
if st == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
18
server/internal/utils/utils_test.go
Normal file
18
server/internal/utils/utils_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInSliceString(t *testing.T) {
|
||||
sl := []string{"a", "b", "c"}
|
||||
require.Equal(t, true, InSliceString(sl, "b"))
|
||||
|
||||
sl = []string{"a", "a", "a"}
|
||||
require.Equal(t, true, InSliceString(sl, "a"))
|
||||
|
||||
sl = []string{"a", "b", "c"}
|
||||
require.Equal(t, false, InSliceString(sl, "d"))
|
||||
}
|
@@ -3,7 +3,7 @@ package auth
|
||||
// Allow is an auth controller which allows access to all connections and topics.
|
||||
type Allow struct{}
|
||||
|
||||
// Auth returns true if a username and password are acceptable. Allow always
|
||||
// Authenticate returns true if a username and password are acceptable. Allow always
|
||||
// returns true.
|
||||
func (a *Allow) Authenticate(user, password []byte) bool {
|
||||
return true
|
||||
@@ -18,7 +18,7 @@ func (a *Allow) ACL(user []byte, topic string, write bool) bool {
|
||||
// Disallow is an auth controller which disallows access to all connections and topics.
|
||||
type Disallow struct{}
|
||||
|
||||
// Auth returns true if a username and password are acceptable. Disallow always
|
||||
// Authenticate returns true if a username and password are acceptable. Disallow always
|
||||
// returns false.
|
||||
func (d *Disallow) Authenticate(user, password []byte) bool {
|
||||
return false
|
||||
|
@@ -18,11 +18,11 @@ import (
|
||||
type HTTPStats struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
config *Config // configuration values for the listener.
|
||||
system *system.Info // pointers to the server data.
|
||||
address string // the network address to bind to.
|
||||
listen *http.Server // the http server.
|
||||
end int64 // ensure the close methods are only called once.}
|
||||
end uint32 // ensure the close methods are only called once.}
|
||||
}
|
||||
|
||||
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
|
||||
@@ -98,9 +98,7 @@ func (l *HTTPStats) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
|
@@ -38,10 +38,10 @@ type Listener interface {
|
||||
|
||||
// Listeners contains the network listeners for the broker.
|
||||
type Listeners struct {
|
||||
sync.RWMutex
|
||||
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
system *system.Info // pointers to system info.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// New returns a new instance of Listeners.
|
||||
|
@@ -22,11 +22,11 @@ func MockEstablisher(id string, c net.Conn, ac auth.Controller) error {
|
||||
type MockListener struct {
|
||||
sync.RWMutex
|
||||
id string // the id of the listener.
|
||||
Config *Config // configuration for the listener.
|
||||
address string // the network address the listener binds to.
|
||||
Listening bool // indiciate the listener is listening.
|
||||
Serving bool // indicate the listener is serving.
|
||||
Config *Config // configuration for the listener.
|
||||
done chan bool // indicate the listener is done.
|
||||
Serving bool // indicate the listener is serving.
|
||||
Listening bool // indiciate the listener is listening.
|
||||
ErrListen bool // throw an error on listen.
|
||||
}
|
||||
|
||||
@@ -44,15 +44,12 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
|
||||
l.Lock()
|
||||
l.Serving = true
|
||||
l.Unlock()
|
||||
for {
|
||||
select {
|
||||
case <-l.done:
|
||||
return
|
||||
}
|
||||
for range l.done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// SetConfig sets the configuration values of the mock listener.
|
||||
// Listen begins listening for incoming traffic.
|
||||
func (l *MockListener) Listen(s *system.Info) error {
|
||||
if l.ErrListen {
|
||||
return fmt.Errorf("listen failure")
|
||||
@@ -95,7 +92,7 @@ func (l *MockListener) IsServing() bool {
|
||||
return l.Serving
|
||||
}
|
||||
|
||||
// IsServing indicates whether the mock listener is listening.
|
||||
// IsListening indicates whether the mock listener is listening.
|
||||
func (l *MockListener) IsListening() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
@@ -14,11 +14,11 @@ import (
|
||||
type TCP struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
config *Config // configuration values for the listener.
|
||||
protocol string // the TCP protocol to use.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
end int64 // ensure the close methods are only called once.
|
||||
config *Config // configuration values for the listener.
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewTCP initialises and returns a new TCP listener, listening on an address.
|
||||
@@ -63,10 +63,12 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
var err error
|
||||
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
@@ -84,7 +86,7 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
// connection callback for any received.
|
||||
func (l *TCP) Serve(establish EstablishFunc) {
|
||||
for {
|
||||
if atomic.LoadInt64(&l.end) == 1 {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -93,8 +95,10 @@ func (l *TCP) Serve(establish EstablishFunc) {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
go establish(l.id, conn, l.config.Auth)
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
_ = establish(l.id, conn, l.config.Auth)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,8 +108,7 @@ func (l *TCP) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
|
@@ -17,12 +17,14 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMessage = errors.New("Message type not binary")
|
||||
// ErrInvalidMessage indicates that a message payload was not valid.
|
||||
ErrInvalidMessage = errors.New("message type not binary")
|
||||
|
||||
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
|
||||
// websocket compliant connection.
|
||||
wsUpgrader = &websocket.Upgrader{
|
||||
Subprotocols: []string{"mqtt"},
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
)
|
||||
|
||||
@@ -30,11 +32,11 @@ var (
|
||||
type Websocket struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
config *Config // configuration values for the listener.
|
||||
address string // the network address to bind to.
|
||||
config *Config // configuration values for the listener.
|
||||
listen *http.Server // an http server for serving websocket connections.
|
||||
end int64 // ensure the close methods are only called once.
|
||||
establish EstablishFunc // the server's establish conection handler.
|
||||
establish EstablishFunc // the server's establish connection handler.
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// wsConn is a websocket connection which satisfies the net.Conn interface.
|
||||
@@ -161,8 +163,7 @@ func (l *Websocket) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
sgob "github.com/asdine/storm/codec/gob"
|
||||
@@ -12,12 +12,17 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPath = "mochi.db"
|
||||
|
||||
// defaultPath is the default file to use to store the data.
|
||||
defaultPath = "mochi.db"
|
||||
|
||||
// defaultTimeout is the default timeout of the file lock.
|
||||
defaultTimeout = 250 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
errNotFound = "not found"
|
||||
// ErrDBNotOpen indicates the bolt db file is not open for reading.
|
||||
ErrDBNotOpen = fmt.Errorf("boltdb not opened")
|
||||
)
|
||||
|
||||
// Store is a backend for writing and reading to bolt persistent storage.
|
||||
@@ -64,7 +69,7 @@ func (s *Store) Close() {
|
||||
// WriteServerInfo writes the server info to the boltdb instance.
|
||||
func (s *Store) WriteServerInfo(v persistence.ServerInfo) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.Save(&v)
|
||||
@@ -77,7 +82,7 @@ func (s *Store) WriteServerInfo(v persistence.ServerInfo) error {
|
||||
// WriteSubscription writes a single subscription to the boltdb instance.
|
||||
func (s *Store) WriteSubscription(v persistence.Subscription) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.Save(&v)
|
||||
@@ -90,7 +95,7 @@ func (s *Store) WriteSubscription(v persistence.Subscription) error {
|
||||
// WriteInflight writes a single inflight message to the boltdb instance.
|
||||
func (s *Store) WriteInflight(v persistence.Message) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.Save(&v)
|
||||
@@ -103,7 +108,7 @@ func (s *Store) WriteInflight(v persistence.Message) error {
|
||||
// WriteRetained writes a single retained message to the boltdb instance.
|
||||
func (s *Store) WriteRetained(v persistence.Message) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.Save(&v)
|
||||
@@ -116,7 +121,7 @@ func (s *Store) WriteRetained(v persistence.Message) error {
|
||||
// WriteClient writes a single client to the boltdb instance.
|
||||
func (s *Store) WriteClient(v persistence.Client) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.Save(&v)
|
||||
@@ -129,7 +134,7 @@ func (s *Store) WriteClient(v persistence.Client) error {
|
||||
// DeleteSubscription deletes a subscription from the boltdb instance.
|
||||
func (s *Store) DeleteSubscription(id string) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.DeleteStruct(&persistence.Subscription{
|
||||
@@ -145,7 +150,7 @@ func (s *Store) DeleteSubscription(id string) error {
|
||||
// DeleteClient deletes a client from the boltdb instance.
|
||||
func (s *Store) DeleteClient(id string) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.DeleteStruct(&persistence.Client{
|
||||
@@ -161,7 +166,7 @@ func (s *Store) DeleteClient(id string) error {
|
||||
// DeleteInflight deletes an inflight message from the boltdb instance.
|
||||
func (s *Store) DeleteInflight(id string) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.DeleteStruct(&persistence.Message{
|
||||
@@ -177,7 +182,7 @@ func (s *Store) DeleteInflight(id string) error {
|
||||
// DeleteRetained deletes a retained message from the boltdb instance.
|
||||
func (s *Store) DeleteRetained(id string) error {
|
||||
if s.db == nil {
|
||||
return errors.New("boltdb not opened")
|
||||
return ErrDBNotOpen
|
||||
}
|
||||
|
||||
err := s.db.DeleteStruct(&persistence.Message{
|
||||
@@ -193,11 +198,11 @@ func (s *Store) DeleteRetained(id string) error {
|
||||
// ReadSubscriptions loads all the subscriptions from the boltdb instance.
|
||||
func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) {
|
||||
if s.db == nil {
|
||||
return v, errors.New("boltdb not opened")
|
||||
return v, ErrDBNotOpen
|
||||
}
|
||||
|
||||
err = s.db.Find("T", persistence.KSubscription, &v)
|
||||
if err != nil && err.Error() != errNotFound {
|
||||
if err != nil && err != storm.ErrNotFound {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,11 +212,11 @@ func (s *Store) ReadSubscriptions() (v []persistence.Subscription, err error) {
|
||||
// ReadClients loads all the clients from the boltdb instance.
|
||||
func (s *Store) ReadClients() (v []persistence.Client, err error) {
|
||||
if s.db == nil {
|
||||
return v, errors.New("boltdb not opened")
|
||||
return v, ErrDBNotOpen
|
||||
}
|
||||
|
||||
err = s.db.Find("T", persistence.KClient, &v)
|
||||
if err != nil && err.Error() != errNotFound {
|
||||
if err != nil && err != storm.ErrNotFound {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -221,11 +226,11 @@ func (s *Store) ReadClients() (v []persistence.Client, err error) {
|
||||
// ReadInflight loads all the inflight messages from the boltdb instance.
|
||||
func (s *Store) ReadInflight() (v []persistence.Message, err error) {
|
||||
if s.db == nil {
|
||||
return v, errors.New("boltdb not opened")
|
||||
return v, ErrDBNotOpen
|
||||
}
|
||||
|
||||
err = s.db.Find("T", persistence.KInflight, &v)
|
||||
if err != nil && err.Error() != errNotFound {
|
||||
if err != nil && err != storm.ErrNotFound {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -235,11 +240,11 @@ func (s *Store) ReadInflight() (v []persistence.Message, err error) {
|
||||
// ReadRetained loads all the retained messages from the boltdb instance.
|
||||
func (s *Store) ReadRetained() (v []persistence.Message, err error) {
|
||||
if s.db == nil {
|
||||
return v, errors.New("boltdb not opened")
|
||||
return v, ErrDBNotOpen
|
||||
}
|
||||
|
||||
err = s.db.Find("T", persistence.KRetained, &v)
|
||||
if err != nil && err.Error() != errNotFound {
|
||||
if err != nil && err != storm.ErrNotFound {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -249,11 +254,11 @@ func (s *Store) ReadRetained() (v []persistence.Message, err error) {
|
||||
//ReadServerInfo loads the server info from the boltdb instance.
|
||||
func (s *Store) ReadServerInfo() (v persistence.ServerInfo, err error) {
|
||||
if s.db == nil {
|
||||
return v, errors.New("boltdb not opened")
|
||||
return v, ErrDBNotOpen
|
||||
}
|
||||
|
||||
err = s.db.One("ID", persistence.KServerInfo, &v)
|
||||
if err != nil && err.Error() != errNotFound {
|
||||
if err != nil && err != storm.ErrNotFound {
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -7,11 +7,21 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
// KSubscription is the key for subscription data.
|
||||
KSubscription = "sub"
|
||||
KServerInfo = "srv"
|
||||
KRetained = "ret"
|
||||
KInflight = "ifm"
|
||||
KClient = "cl"
|
||||
|
||||
// KServerInfo is the key for server info data.
|
||||
KServerInfo = "srv"
|
||||
|
||||
// KRetained is the key for retained messages data.
|
||||
KRetained = "ret"
|
||||
|
||||
// KInflight is the key for inflight messages data.
|
||||
KInflight = "ifm"
|
||||
|
||||
// KClient is the key for client data.
|
||||
KClient = "cl"
|
||||
)
|
||||
|
||||
// Store is an interface which details a persistent storage connector.
|
||||
@@ -53,50 +63,50 @@ type Subscription struct {
|
||||
|
||||
// Message contains the details of a retained or inflight message.
|
||||
type Message struct {
|
||||
ID string // the storage key.
|
||||
T string // the type of the stored data.
|
||||
Client string // the id of the client who sent the message (if inflight).
|
||||
FixedHeader FixedHeader // the header properties of the message.
|
||||
PacketID uint16 // the unique id of the packet (if inflight).
|
||||
TopicName string // the topic the message was sent to (if retained).
|
||||
Payload []byte // the message payload (if retained).
|
||||
FixedHeader FixedHeader // the header properties of the message.
|
||||
T string // the type of the stored data.
|
||||
ID string // the storage key.
|
||||
Client string // the id of the client who sent the message (if inflight).
|
||||
TopicName string // the topic the message was sent to (if retained).
|
||||
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
|
||||
Resends int // the number of times the message was attempted to be sent (if inflight).
|
||||
PacketID uint16 // the unique id of the packet (if inflight).
|
||||
}
|
||||
|
||||
// FixedHeader contains the fixed header properties of a message.
|
||||
type FixedHeader struct {
|
||||
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Dup bool // indicates if the packet was already sent at an earlier time.
|
||||
Qos byte // indicates the quality of service expected.
|
||||
Retain bool // whether the message should be retained.
|
||||
Remaining int // the number of remaining bytes in the payload.
|
||||
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Qos byte // indicates the quality of service expected.
|
||||
Dup bool // indicates if the packet was already sent at an earlier time.
|
||||
Retain bool // whether the message should be retained.
|
||||
}
|
||||
|
||||
// Client contains client data that can be persistently stored.
|
||||
type Client struct {
|
||||
LWT LWT // the last-will-and-testament message for the client.
|
||||
Username []byte // the username the client authenticated with.
|
||||
ID string // the storage key.
|
||||
ClientID string // the id of the client.
|
||||
T string // the type of the stored data.
|
||||
Listener string // the last known listener id for the client
|
||||
Username []byte // the username the client authenticated with.
|
||||
LWT LWT // the last-will-and-testament message for the client.
|
||||
}
|
||||
|
||||
// LWT contains details about a clients LWT payload.
|
||||
type LWT struct {
|
||||
Topic string // the topic the will message shall be sent to.
|
||||
Message []byte // the message that shall be sent when the client disconnects.
|
||||
Topic string // the topic the will message shall be sent to.
|
||||
Qos byte // the quality of service desired.
|
||||
Retain bool // indicates whether the will message should be retained
|
||||
}
|
||||
|
||||
// MockStore is a mock storage backend for testing.
|
||||
type MockStore struct {
|
||||
Fail map[string]bool // issue errors for different methods.
|
||||
FailOpen bool // error on open.
|
||||
Closed bool // indicate mock store is closed.
|
||||
Opened bool // indicate mock store is open.
|
||||
Fail map[string]bool // issue errors for different methods.
|
||||
}
|
||||
|
||||
// Open opens the storage instance.
|
||||
@@ -197,7 +207,7 @@ func (s *MockStore) ReadSubscriptions() (v []Subscription, err error) {
|
||||
}
|
||||
|
||||
return []Subscription{
|
||||
Subscription{
|
||||
{
|
||||
ID: "test:a/b/c",
|
||||
Client: "test",
|
||||
Filter: "a/b/c",
|
||||
@@ -214,7 +224,7 @@ func (s *MockStore) ReadClients() (v []Client, err error) {
|
||||
}
|
||||
|
||||
return []Client{
|
||||
Client{
|
||||
{
|
||||
ID: "cl_client1",
|
||||
ClientID: "client1",
|
||||
T: KClient,
|
||||
@@ -230,7 +240,7 @@ func (s *MockStore) ReadInflight() (v []Message, err error) {
|
||||
}
|
||||
|
||||
return []Message{
|
||||
Message{
|
||||
{
|
||||
ID: "client1_if_100",
|
||||
T: KInflight,
|
||||
Client: "client1",
|
||||
@@ -250,7 +260,7 @@ func (s *MockStore) ReadRetained() (v []Message, err error) {
|
||||
}
|
||||
|
||||
return []Message{
|
||||
Message{
|
||||
{
|
||||
ID: "client1_ret_200",
|
||||
T: KRetained,
|
||||
FixedHeader: FixedHeader{
|
||||
|
463
server/server.go
463
server/server.go
@@ -1,9 +1,10 @@
|
||||
// packet server provides a MQTT 3.1.1 compliant MQTT server.
|
||||
// package server provides a MQTT 3.1.1 compliant MQTT server.
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"github.com/mochi-co/mqtt/server/internal/clients"
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
"github.com/mochi-co/mqtt/server/internal/topics"
|
||||
"github.com/mochi-co/mqtt/server/internal/utils"
|
||||
"github.com/mochi-co/mqtt/server/listeners"
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/persistence"
|
||||
@@ -21,14 +23,41 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "1.0.2" // the server version.
|
||||
// Version indicates the current server version.
|
||||
Version = "1.1.1"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrListenerIDExists = errors.New("Listener id already exists")
|
||||
ErrReadConnectInvalid = errors.New("Connect packet was not valid")
|
||||
ErrConnectNotAuthorized = errors.New("Connect packet was not authorized")
|
||||
ErrInvalidTopic = errors.New("Cannot publish to $ and $SYS topics")
|
||||
// ErrListenerIDExists indicates that a listener with the same id already exists.
|
||||
ErrListenerIDExists = errors.New("listener id already exists")
|
||||
|
||||
// ErrReadConnectInvalid indicates that the connection packet was invalid.
|
||||
ErrReadConnectInvalid = errors.New("connect packet was not valid")
|
||||
|
||||
// ErrConnectNotAuthorized indicates that the connection packet had incorrect auth values.
|
||||
ErrConnectNotAuthorized = errors.New("connect packet was not authorized")
|
||||
|
||||
// ErrInvalidTopic indicates that the specified topic was not valid.
|
||||
ErrInvalidTopic = errors.New("cannot publish to $ and $SYS topics")
|
||||
|
||||
// ErrRejectPacket indicates that a packet should be dropped instead of processed.
|
||||
ErrRejectPacket = errors.New("packet rejected")
|
||||
|
||||
// ErrClientDisconnect indicates that a client disconnected from the server.
|
||||
ErrClientDisconnect = errors.New("client disconnected")
|
||||
|
||||
// ErrClientReconnect indicates that a client attempted to reconnect while still connected.
|
||||
ErrClientReconnect = errors.New("client sent connect while connected")
|
||||
|
||||
// ErrServerShutdown is propagated when the server shuts down.
|
||||
ErrServerShutdown = errors.New("server is shutting down")
|
||||
|
||||
// ErrSessionReestablished indicates that an existing client was replaced by a newly connected
|
||||
// client. The existing client is disconnected.
|
||||
ErrSessionReestablished = errors.New("client session re-established")
|
||||
|
||||
// ErrConnectionFailed indicates that a client connection attempt failed for other reasons.
|
||||
ErrConnectionFailed = errors.New("connection attempt failed")
|
||||
|
||||
// SysTopicInterval is the number of milliseconds between $SYS topic publishes.
|
||||
SysTopicInterval time.Duration = 30000
|
||||
@@ -44,16 +73,26 @@ var (
|
||||
// Server is an MQTT broker server. It should be created with server.New()
|
||||
// in order to ensure all the internal fields are correctly populated.
|
||||
type Server struct {
|
||||
inline inlineMessages // channels for direct publishing.
|
||||
Events events.Events // overrideable event hooks.
|
||||
Store persistence.Store // a persistent storage backend if desired.
|
||||
Options *Options // configurable server options.
|
||||
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
|
||||
Clients *clients.Clients // clients which are known to the broker.
|
||||
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
|
||||
System *system.Info // values about the server commonly found in $SYS topics.
|
||||
Store persistence.Store // a persistent storage backend if desired.
|
||||
done chan bool // indicate that the server is ending.
|
||||
bytepool circ.BytesPool // a byte pool for incoming and outgoing packets.
|
||||
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
|
||||
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
|
||||
inline inlineMessages // channels for direct publishing.
|
||||
Events events.Events // overrideable event hooks.
|
||||
done chan bool // indicate that the server is ending.
|
||||
}
|
||||
|
||||
// Options contains configurable options for the server.
|
||||
type Options struct {
|
||||
// BufferSize overrides the default buffer size (circ.DefaultBufferSize) for the client buffers.
|
||||
BufferSize int
|
||||
|
||||
// BufferBlockSize overrides the default buffer block size (DefaultBlockSize) for the client buffers.
|
||||
BufferBlockSize int
|
||||
}
|
||||
|
||||
// inlineMessages contains channels for handling inline (direct) publishing.
|
||||
@@ -62,11 +101,22 @@ type inlineMessages struct {
|
||||
pub chan packets.Packet // a channel of packets to publish to clients
|
||||
}
|
||||
|
||||
// New returns a new instance of an MQTT broker.
|
||||
// New returns a new instance of MQTT server with no options.
|
||||
// This method has been deprecated and will be removed in a future release.
|
||||
// Please use NewServer instead.
|
||||
func New() *Server {
|
||||
return NewServer(nil)
|
||||
}
|
||||
|
||||
// NewServer returns a new instance of an MQTT broker with optional values where applicable.
|
||||
func NewServer(opts *Options) *Server {
|
||||
if opts == nil {
|
||||
opts = new(Options)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
done: make(chan bool),
|
||||
bytepool: circ.NewBytesPool(circ.DefaultBufferSize),
|
||||
bytepool: circ.NewBytesPool(opts.BufferSize),
|
||||
Clients: clients.New(),
|
||||
Topics: topics.New(),
|
||||
System: &system.Info{
|
||||
@@ -78,7 +128,8 @@ func New() *Server {
|
||||
done: make(chan bool),
|
||||
pub: make(chan packets.Packet, 1024),
|
||||
},
|
||||
Events: events.Events{},
|
||||
Events: events.Events{},
|
||||
Options: opts,
|
||||
}
|
||||
|
||||
// Expose server stats using the system listener so it can be used in the
|
||||
@@ -165,118 +216,200 @@ func (s *Server) inlineClient() {
|
||||
}
|
||||
}
|
||||
|
||||
// readConnectionPacket reads the first incoming header for a connection, and if
|
||||
// acceptable, returns the valid connection packet.
|
||||
func (s *Server) readConnectionPacket(cl *clients.Client) (pk packets.Packet, err error) {
|
||||
fh := new(packets.FixedHeader)
|
||||
err = cl.ReadFixedHeader(fh)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pk, err = cl.ReadPacket(fh)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Type != packets.Connect {
|
||||
return pk, ErrReadConnectInvalid
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// onError is a pass-through method which triggers the OnError
|
||||
// event hook (if applicable), and returns the provided error.
|
||||
func (s *Server) onError(cl events.Client, err error) error {
|
||||
if err == nil {
|
||||
return err
|
||||
}
|
||||
// Note: if the error originates from a real cause, it will
|
||||
// have been captured as the StopCause. The two cases ignored
|
||||
// below are ordinary consequences of closing the connection.
|
||||
// If one of these ordinary conditions stops the connection,
|
||||
// then the client closed or broke the connection.
|
||||
if s.Events.OnError != nil &&
|
||||
!errors.Is(err, io.EOF) {
|
||||
s.Events.OnError(cl, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// onStorage is a pass-through method which delegates errors from
|
||||
// the persistent storage adapter to the onError event hook.
|
||||
func (s *Server) onStorage(cl events.Clientlike, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = s.onError(cl.Info(), fmt.Errorf("storage: %w", err))
|
||||
}
|
||||
|
||||
// EstablishConnection establishes a new client when a listener
|
||||
// accepts a new connection.
|
||||
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
|
||||
xbr := s.bytepool.Get() // Get byte buffer from pools for receiving packet data.
|
||||
xbw := s.bytepool.Get() // and for sending.
|
||||
defer s.bytepool.Put(xbr)
|
||||
defer s.bytepool.Put(xbw)
|
||||
|
||||
cl := clients.NewClient(c,
|
||||
circ.NewReaderFromSlice(0, xbr),
|
||||
circ.NewWriterFromSlice(0, xbw),
|
||||
circ.NewReaderFromSlice(s.Options.BufferBlockSize, xbr),
|
||||
circ.NewWriterFromSlice(s.Options.BufferBlockSize, xbw),
|
||||
s.System,
|
||||
)
|
||||
|
||||
cl.Start()
|
||||
defer cl.ClearBuffers()
|
||||
defer cl.Stop(nil)
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
pk, err := s.readConnectionPacket(cl)
|
||||
if err != nil {
|
||||
return err
|
||||
return s.onError(cl.Info(), fmt.Errorf("read connection: %w", err))
|
||||
}
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
ackCode, err := pk.ConnectValidate()
|
||||
if err != nil {
|
||||
return err
|
||||
if err := s.ackConnection(cl, ackCode, false); err != nil {
|
||||
return s.onError(cl.Info(), fmt.Errorf("invalid connection send ack: %w", err))
|
||||
}
|
||||
return s.onError(cl.Info(), fmt.Errorf("validate connection packet: %w", err))
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Type != packets.Connect {
|
||||
return ErrReadConnectInvalid
|
||||
}
|
||||
cl.Identify(lid, pk, ac) // Set client identity values from the connection packet.
|
||||
|
||||
cl.Identify(lid, pk, ac)
|
||||
|
||||
retcode, _ := pk.ConnectValidate()
|
||||
if !ac.Authenticate(pk.Username, pk.Password) {
|
||||
retcode = packets.CodeConnectBadAuthValues
|
||||
if err := s.ackConnection(cl, packets.CodeConnectBadAuthValues, false); err != nil {
|
||||
return s.onError(cl.Info(), fmt.Errorf("invalid connection send ack: %w", err))
|
||||
}
|
||||
return s.onError(cl.Info(), ErrConnectionFailed)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&s.System.ConnectionsTotal, 1)
|
||||
atomic.AddInt64(&s.System.ClientsConnected, 1)
|
||||
defer atomic.AddInt64(&s.System.ClientsConnected, -1)
|
||||
defer atomic.AddInt64(&s.System.ClientsDisconnected, 1)
|
||||
|
||||
var sessionPresent bool
|
||||
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
|
||||
existing.Lock()
|
||||
if atomic.LoadInt64(&existing.State.Done) == 1 {
|
||||
atomic.AddInt64(&s.System.ClientsDisconnected, -1)
|
||||
}
|
||||
existing.Stop()
|
||||
if pk.CleanSession {
|
||||
for k := range existing.Subscriptions {
|
||||
delete(existing.Subscriptions, k)
|
||||
q := s.Topics.Unsubscribe(k, existing.ID)
|
||||
if q {
|
||||
atomic.AddInt64(&s.System.Subscriptions, -1)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cl.Inflight = existing.Inflight // Inherit from existing session.
|
||||
cl.Subscriptions = existing.Subscriptions
|
||||
sessionPresent = true
|
||||
}
|
||||
existing.Unlock()
|
||||
} else {
|
||||
atomic.AddInt64(&s.System.ClientsTotal, 1)
|
||||
if atomic.LoadInt64(&s.System.ClientsConnected) > atomic.LoadInt64(&s.System.ClientsMax) {
|
||||
atomic.AddInt64(&s.System.ClientsMax, 1)
|
||||
}
|
||||
sessionPresent := s.inheritClientSession(pk, cl)
|
||||
s.Clients.Add(cl)
|
||||
|
||||
err = s.ackConnection(cl, ackCode, sessionPresent)
|
||||
if err != nil {
|
||||
return s.onError(cl.Info(), fmt.Errorf("ack connection packet: %w", err))
|
||||
}
|
||||
|
||||
s.Clients.Add(cl) // Overwrite any existing client with the same name.
|
||||
|
||||
err = s.writeClient(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Connack,
|
||||
},
|
||||
SessionPresent: sessionPresent,
|
||||
ReturnCode: retcode,
|
||||
})
|
||||
if err != nil || retcode != packets.Accepted {
|
||||
return err
|
||||
err = s.ResendClientInflight(cl, true)
|
||||
if err != nil {
|
||||
s.onError(cl.Info(), fmt.Errorf("resend in flight: %w", err)) // pass-through, no return.
|
||||
}
|
||||
|
||||
s.ResendClientInflight(cl, true)
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.WriteClient(persistence.Client{
|
||||
s.onStorage(cl, s.Store.WriteClient(persistence.Client{
|
||||
ID: "cl_" + cl.ID,
|
||||
ClientID: cl.ID,
|
||||
T: persistence.KClient,
|
||||
Listener: cl.Listener,
|
||||
Username: cl.Username,
|
||||
LWT: persistence.LWT(cl.LWT),
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
err = cl.Read(s.processPacket)
|
||||
if err != nil {
|
||||
s.closeClient(cl, true)
|
||||
if s.Events.OnConnect != nil {
|
||||
s.Events.OnConnect(cl.Info(), events.Packet(pk))
|
||||
}
|
||||
|
||||
s.bytepool.Put(xbr) // Return byte buffers to pools when the client has finished.
|
||||
s.bytepool.Put(xbw)
|
||||
if err := cl.Read(s.processPacket); err != nil {
|
||||
s.sendLWT(cl)
|
||||
cl.Stop(err)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&s.System.ClientsConnected, -1)
|
||||
atomic.AddInt64(&s.System.ClientsDisconnected, 1)
|
||||
err = cl.StopCause() // Determine true cause of stop.
|
||||
|
||||
if s.Events.OnDisconnect != nil {
|
||||
s.Events.OnDisconnect(cl.Info(), err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ackConnection returns a Connack packet to a client.
|
||||
func (s *Server) ackConnection(cl *clients.Client, ack byte, present bool) error {
|
||||
return s.writeClient(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Connack,
|
||||
},
|
||||
SessionPresent: present,
|
||||
ReturnCode: ack,
|
||||
})
|
||||
}
|
||||
|
||||
// inheritClientSession inherits the state of an existing client sharing the same
|
||||
// connection ID. If cleanSession is true, the state of any previously existing client
|
||||
// session is abandoned.
|
||||
func (s *Server) inheritClientSession(pk packets.Packet, cl *clients.Client) bool {
|
||||
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
|
||||
existing.Lock()
|
||||
defer existing.Unlock()
|
||||
|
||||
existing.Stop(ErrSessionReestablished) // Issue a stop on the old client.
|
||||
|
||||
// Per [MQTT-3.1.2-6]:
|
||||
// If CleanSession is set to 1, the Client and Server MUST discard any previous Session and start a new one.
|
||||
// The state associated with a CleanSession MUST NOT be reused in any subsequent session.
|
||||
if pk.CleanSession || existing.CleanSession {
|
||||
s.unsubscribeClient(existing)
|
||||
return false
|
||||
}
|
||||
|
||||
cl.Inflight = existing.Inflight // Take address of existing session.
|
||||
cl.Subscriptions = existing.Subscriptions
|
||||
return true
|
||||
|
||||
} else {
|
||||
atomic.AddInt64(&s.System.ClientsTotal, 1)
|
||||
if atomic.LoadInt64(&s.System.ClientsConnected) > atomic.LoadInt64(&s.System.ClientsMax) {
|
||||
atomic.AddInt64(&s.System.ClientsMax, 1)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// unsubscribeClient unsubscribes a client from all of their subscriptions.
|
||||
func (s *Server) unsubscribeClient(cl *clients.Client) {
|
||||
for k := range cl.Subscriptions {
|
||||
delete(cl.Subscriptions, k)
|
||||
if s.Topics.Unsubscribe(k, cl.ID) {
|
||||
atomic.AddInt64(&s.System.Subscriptions, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeClient writes packets to a client connection.
|
||||
func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
|
||||
_, err := cl.WritePacket(pk)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -327,13 +460,14 @@ func (s *Server) processPacket(cl *clients.Client, pk packets.Packet) error {
|
||||
// establish a new connection on an existing connection. See EstablishConnection
|
||||
// instead.
|
||||
func (s *Server) processConnect(cl *clients.Client, pk packets.Packet) error {
|
||||
s.closeClient(cl, true)
|
||||
s.sendLWT(cl)
|
||||
cl.Stop(ErrClientReconnect)
|
||||
return nil
|
||||
}
|
||||
|
||||
// processDisconnect processes a Disconnect packet.
|
||||
func (s *Server) processDisconnect(cl *clients.Client, pk packets.Packet) error {
|
||||
s.closeClient(cl, false)
|
||||
cl.Stop(ErrClientDisconnect)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -362,14 +496,15 @@ func (s *Server) Publish(topic string, payload []byte, retain bool) error {
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Type: packets.Publish,
|
||||
Retain: retain,
|
||||
},
|
||||
TopicName: topic,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
if retain {
|
||||
s.retainMessage(pk)
|
||||
s.retainMessage(&s.inline, pk)
|
||||
}
|
||||
|
||||
// handoff packet to s.inline.pub channel for writing to client buffers
|
||||
@@ -379,6 +514,16 @@ func (s *Server) Publish(topic string, payload []byte, retain bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Info provides pseudo-client information for the inline messages processor.
|
||||
// It provides a 'client' to which inline retained messages can be assigned.
|
||||
func (*inlineMessages) Info() events.Client {
|
||||
return events.Client{
|
||||
ID: "inline",
|
||||
Remote: "inline",
|
||||
Listener: "inline",
|
||||
}
|
||||
}
|
||||
|
||||
// processPublish processes a Publish packet.
|
||||
func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
|
||||
if len(pk.TopicName) >= 4 && pk.TopicName[0:4] == "$SYS" {
|
||||
@@ -389,8 +534,25 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// if an OnProcessMessage hook exists, potentially modify the packet.
|
||||
if s.Events.OnProcessMessage != nil {
|
||||
pkx, err := s.Events.OnProcessMessage(cl.Info(), events.Packet(pk))
|
||||
if err == nil {
|
||||
pk = packets.Packet(pkx) // Only use the new package changes if there's no errors.
|
||||
} else {
|
||||
// If the ErrRejectPacket is return, abandon processing the packet.
|
||||
if err == ErrRejectPacket {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.Events.OnError != nil {
|
||||
s.Events.OnError(cl.Info(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Retain {
|
||||
s.retainMessage(pk)
|
||||
s.retainMessage(cl, pk)
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Qos > 0 {
|
||||
@@ -407,12 +569,12 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
|
||||
|
||||
// omit errors in case of broken connection / LWT publish. ack send failures
|
||||
// will be handled by in-flight resending on next reconnect.
|
||||
s.writeClient(cl, ack)
|
||||
s.onError(cl.Info(), s.writeClient(cl, ack))
|
||||
}
|
||||
|
||||
// if an OnMessage hook exists, potentially modify the packet.
|
||||
if s.Events.OnMessage != nil {
|
||||
if pkx, err := s.Events.OnMessage(events.FromClient(*cl), events.Packet(pk)); err == nil {
|
||||
if pkx, err := s.Events.OnMessage(cl.Info(), events.Packet(pk)); err == nil {
|
||||
pk = packets.Packet(pkx)
|
||||
}
|
||||
}
|
||||
@@ -425,21 +587,23 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
|
||||
|
||||
// retainMessage adds a message to a topic, and if a persistent store is provided,
|
||||
// adds the message to the store so it can be reloaded if necessary.
|
||||
func (s *Server) retainMessage(pk packets.Packet) {
|
||||
func (s *Server) retainMessage(cl events.Clientlike, pk packets.Packet) {
|
||||
out := pk.PublishCopy()
|
||||
q := s.Topics.RetainMessage(out)
|
||||
atomic.AddInt64(&s.System.Retained, q)
|
||||
r := s.Topics.RetainMessage(out)
|
||||
atomic.AddInt64(&s.System.Retained, r)
|
||||
|
||||
if s.Store != nil {
|
||||
if q == 1 {
|
||||
s.Store.WriteRetained(persistence.Message{
|
||||
ID: "ret_" + out.TopicName,
|
||||
id := "ret_" + out.TopicName
|
||||
if r == 1 {
|
||||
s.onStorage(cl, s.Store.WriteRetained(persistence.Message{
|
||||
ID: id,
|
||||
T: persistence.KRetained,
|
||||
FixedHeader: persistence.FixedHeader(out.FixedHeader),
|
||||
TopicName: out.TopicName,
|
||||
Payload: out.Payload,
|
||||
})
|
||||
}))
|
||||
} else {
|
||||
s.Store.DeleteRetained("ret_" + out.TopicName)
|
||||
s.onStorage(cl, s.Store.DeleteRetained(id))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -449,6 +613,14 @@ func (s *Server) retainMessage(pk packets.Packet) {
|
||||
func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
for id, qos := range s.Topics.Subscribers(pk.TopicName) {
|
||||
if client, ok := s.Clients.Get(id); ok {
|
||||
|
||||
// If the AllowClients value is set, only deliver the packet if the subscribed
|
||||
// client exists in the AllowClients value. For use with the OnMessage event hook
|
||||
// in cases where you want to publish messages to clients selectively.
|
||||
if pk.AllowClients != nil && !utils.InSliceString(pk.AllowClients, id) {
|
||||
continue
|
||||
}
|
||||
|
||||
out := pk.PublishCopy()
|
||||
if qos > out.FixedHeader.Qos { // Inherit higher desired qos values.
|
||||
out.FixedHeader.Qos = qos
|
||||
@@ -473,18 +645,18 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.WriteInflight(persistence.Message{
|
||||
ID: "if_" + client.ID + "_" + strconv.Itoa(int(out.PacketID)),
|
||||
T: persistence.KRetained,
|
||||
s.onStorage(client, s.Store.WriteInflight(persistence.Message{
|
||||
ID: persistentID(client, out),
|
||||
T: persistence.KInflight,
|
||||
FixedHeader: persistence.FixedHeader(out.FixedHeader),
|
||||
TopicName: out.TopicName,
|
||||
Payload: out.Payload,
|
||||
Sent: sent,
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
s.writeClient(client, out)
|
||||
s.onError(client.Info(), s.writeClient(client, out))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -496,7 +668,7 @@ func (s *Server) processPuback(cl *clients.Client, pk packets.Packet) error {
|
||||
atomic.AddInt64(&s.System.Inflight, -1)
|
||||
}
|
||||
if s.Store != nil {
|
||||
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
|
||||
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -538,7 +710,7 @@ func (s *Server) processPubrel(cl *clients.Client, pk packets.Packet) error {
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
|
||||
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -551,7 +723,7 @@ func (s *Server) processPubcomp(cl *clients.Client, pk packets.Packet) error {
|
||||
atomic.AddInt64(&s.System.Inflight, -1)
|
||||
}
|
||||
if s.Store != nil {
|
||||
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(pk.PacketID)))
|
||||
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, pk)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -571,13 +743,13 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
|
||||
retCodes[i] = pk.Qoss[i]
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.WriteSubscription(persistence.Subscription{
|
||||
s.onStorage(cl, s.Store.WriteSubscription(persistence.Subscription{
|
||||
ID: "sub_" + cl.ID + ":" + pk.Topics[i],
|
||||
T: persistence.KSubscription,
|
||||
Filter: pk.Topics[i],
|
||||
Client: cl.ID,
|
||||
QoS: pk.Qoss[i],
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -593,10 +765,15 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Publish out any retained messages matching the subscription filter.
|
||||
// Publish out any retained messages matching the subscription filter and the user has
|
||||
// been allowed to subscribe to.
|
||||
for i := 0; i < len(pk.Topics); i++ {
|
||||
if retCodes[i] == packets.ErrSubAckNetworkError {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pkv := range s.Topics.Messages(pk.Topics[i]) {
|
||||
s.writeClient(cl, pkv) // omit errors, prefer continuing.
|
||||
s.onError(cl.Info(), s.writeClient(cl, pkv))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -626,6 +803,17 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// atomicItoa reads an *int64 and formats a decimal string.
|
||||
func atomicItoa(ptr *int64) string {
|
||||
return strconv.FormatInt(atomic.LoadInt64(ptr), 10)
|
||||
}
|
||||
|
||||
// persistentID return a string combining the client and packet
|
||||
// identifiers for use with the persistence layer.
|
||||
func persistentID(client *clients.Client, pk packets.Packet) string {
|
||||
return "if_" + client.ID + "_" + pk.FormatID()
|
||||
}
|
||||
|
||||
// publishSysTopics publishes the current values to the server $SYS topics.
|
||||
// Due to the int to string conversions this method is not as cheap as
|
||||
// some of the others so the publishing interval should be set appropriately.
|
||||
@@ -637,26 +825,27 @@ func (s *Server) publishSysTopics() {
|
||||
},
|
||||
}
|
||||
|
||||
s.System.Uptime = time.Now().Unix() - s.System.Started
|
||||
uptime := time.Now().Unix() - atomic.LoadInt64(&s.System.Started)
|
||||
atomic.StoreInt64(&s.System.Uptime, uptime)
|
||||
topics := map[string]string{
|
||||
"$SYS/broker/version": s.System.Version,
|
||||
"$SYS/broker/uptime": strconv.Itoa(int(s.System.Uptime)),
|
||||
"$SYS/broker/timestamp": strconv.Itoa(int(s.System.Started)),
|
||||
"$SYS/broker/load/bytes/received": strconv.Itoa(int(s.System.BytesRecv)),
|
||||
"$SYS/broker/load/bytes/sent": strconv.Itoa(int(s.System.BytesSent)),
|
||||
"$SYS/broker/clients/connected": strconv.Itoa(int(s.System.ClientsConnected)),
|
||||
"$SYS/broker/clients/disconnected": strconv.Itoa(int(s.System.ClientsDisconnected)),
|
||||
"$SYS/broker/clients/maximum": strconv.Itoa(int(s.System.ClientsMax)),
|
||||
"$SYS/broker/clients/total": strconv.Itoa(int(s.System.ClientsTotal)),
|
||||
"$SYS/broker/connections/total": strconv.Itoa(int(s.System.ConnectionsTotal)),
|
||||
"$SYS/broker/messages/received": strconv.Itoa(int(s.System.MessagesRecv)),
|
||||
"$SYS/broker/messages/sent": strconv.Itoa(int(s.System.MessagesSent)),
|
||||
"$SYS/broker/messages/publish/dropped": strconv.Itoa(int(s.System.PublishDropped)),
|
||||
"$SYS/broker/messages/publish/received": strconv.Itoa(int(s.System.PublishRecv)),
|
||||
"$SYS/broker/messages/publish/sent": strconv.Itoa(int(s.System.PublishSent)),
|
||||
"$SYS/broker/messages/retained/count": strconv.Itoa(int(s.System.Retained)),
|
||||
"$SYS/broker/messages/inflight": strconv.Itoa(int(s.System.Inflight)),
|
||||
"$SYS/broker/subscriptions/count": strconv.Itoa(int(s.System.Subscriptions)),
|
||||
"$SYS/broker/uptime": atomicItoa(&s.System.Uptime),
|
||||
"$SYS/broker/timestamp": atomicItoa(&s.System.Started),
|
||||
"$SYS/broker/load/bytes/received": atomicItoa(&s.System.BytesRecv),
|
||||
"$SYS/broker/load/bytes/sent": atomicItoa(&s.System.BytesSent),
|
||||
"$SYS/broker/clients/connected": atomicItoa(&s.System.ClientsConnected),
|
||||
"$SYS/broker/clients/disconnected": atomicItoa(&s.System.ClientsDisconnected),
|
||||
"$SYS/broker/clients/maximum": atomicItoa(&s.System.ClientsMax),
|
||||
"$SYS/broker/clients/total": atomicItoa(&s.System.ClientsTotal),
|
||||
"$SYS/broker/connections/total": atomicItoa(&s.System.ConnectionsTotal),
|
||||
"$SYS/broker/messages/received": atomicItoa(&s.System.MessagesRecv),
|
||||
"$SYS/broker/messages/sent": atomicItoa(&s.System.MessagesSent),
|
||||
"$SYS/broker/messages/publish/dropped": atomicItoa(&s.System.PublishDropped),
|
||||
"$SYS/broker/messages/publish/received": atomicItoa(&s.System.PublishRecv),
|
||||
"$SYS/broker/messages/publish/sent": atomicItoa(&s.System.PublishSent),
|
||||
"$SYS/broker/messages/retained/count": atomicItoa(&s.System.Retained),
|
||||
"$SYS/broker/messages/inflight": atomicItoa(&s.System.Inflight),
|
||||
"$SYS/broker/subscriptions/count": atomicItoa(&s.System.Subscriptions),
|
||||
}
|
||||
|
||||
for topic, payload := range topics {
|
||||
@@ -668,10 +857,10 @@ func (s *Server) publishSysTopics() {
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.WriteServerInfo(persistence.ServerInfo{
|
||||
s.onStorage(&s.inline, s.Store.WriteServerInfo(persistence.ServerInfo{
|
||||
Info: *s.System,
|
||||
ID: persistence.KServerInfo,
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -691,7 +880,7 @@ func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.DeleteInflight("if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)))
|
||||
s.onStorage(cl, s.Store.DeleteInflight(persistentID(cl, tk.Packet)))
|
||||
}
|
||||
|
||||
continue
|
||||
@@ -715,15 +904,15 @@ func (s *Server) ResendClientInflight(cl *clients.Client, force bool) error {
|
||||
}
|
||||
|
||||
if s.Store != nil {
|
||||
s.Store.WriteInflight(persistence.Message{
|
||||
ID: "if_" + cl.ID + "_" + strconv.Itoa(int(tk.Packet.PacketID)),
|
||||
T: persistence.KRetained,
|
||||
s.onStorage(cl, s.Store.WriteInflight(persistence.Message{
|
||||
ID: persistentID(cl, tk.Packet),
|
||||
T: persistence.KInflight,
|
||||
FixedHeader: persistence.FixedHeader(tk.Packet.FixedHeader),
|
||||
TopicName: tk.Packet.TopicName,
|
||||
Payload: tk.Packet.Payload,
|
||||
Sent: tk.Sent,
|
||||
Resends: tk.Resends,
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -746,15 +935,14 @@ func (s *Server) Close() error {
|
||||
func (s *Server) closeListenerClients(listener string) {
|
||||
clients := s.Clients.GetByListener(listener)
|
||||
for _, cl := range clients {
|
||||
s.closeClient(cl, false) // omit errors
|
||||
cl.Stop(ErrServerShutdown)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// closeClient closes a client connection and publishes any LWT messages.
|
||||
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
|
||||
if sendLWT && cl.LWT.Topic != "" {
|
||||
s.processPublish(cl, packets.Packet{
|
||||
// sendLWT issues an LWT message to a topic when a client disconnects.
|
||||
func (s *Server) sendLWT(cl *clients.Client) error {
|
||||
if cl.LWT.Topic != "" {
|
||||
err := s.processPublish(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Retain: cl.LWT.Retain,
|
||||
@@ -763,10 +951,11 @@ func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
|
||||
TopicName: cl.LWT.Topic,
|
||||
Payload: cl.LWT.Message,
|
||||
})
|
||||
if err != nil {
|
||||
return s.onError(cl.Info(), fmt.Errorf("send lwt: %s %w; %+v", cl.ID, err, cl.LWT))
|
||||
}
|
||||
}
|
||||
|
||||
cl.Stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
21
server/system/system_test.go
Normal file
21
server/system/system_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInfoAlignment(t *testing.T) {
|
||||
typ := reflect.TypeOf(Info{})
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
switch f.Type.Kind() {
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
require.Equalf(t, uintptr(0), f.Offset%8,
|
||||
"%s requires 64-bit alignment for atomic: offset %d",
|
||||
f.Name, f.Offset)
|
||||
}
|
||||
}
|
||||
}
|
39
vendor/github.com/gorilla/websocket/README.md
generated
vendored
39
vendor/github.com/gorilla/websocket/README.md
generated
vendored
@@ -6,6 +6,13 @@
|
||||
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
|
||||
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
||||
|
||||
|
||||
---
|
||||
|
||||
⚠️ **[The Gorilla WebSocket Package is looking for a new maintainer](https://github.com/gorilla/websocket/issues/370)**
|
||||
|
||||
---
|
||||
|
||||
### Documentation
|
||||
|
||||
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
|
||||
@@ -30,35 +37,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
||||
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
||||
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
||||
|
||||
### Gorilla WebSocket compared with other packages
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th></th>
|
||||
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
|
||||
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
|
||||
</tr>
|
||||
<tr>
|
||||
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
|
||||
<tr><td>Passes <a href="https://github.com/crossbario/autobahn-testsuite">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
|
||||
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
|
||||
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
|
||||
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
|
||||
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
|
||||
<tr><td colspan="3">Other Features</tr></td>
|
||||
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
|
||||
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
|
||||
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
|
||||
</table>
|
||||
|
||||
Notes:
|
||||
|
||||
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
|
||||
2. The application can get the type of a received data message by implementing
|
||||
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
|
||||
function.
|
||||
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
|
||||
Read returns when the input buffer is full or a frame boundary is
|
||||
encountered. Each call to Write sends a single frame message. The Gorilla
|
||||
io.Reader and io.WriteCloser operate on a single WebSocket message.
|
||||
|
||||
|
77
vendor/github.com/gorilla/websocket/client.go
generated
vendored
77
vendor/github.com/gorilla/websocket/client.go
generated
vendored
@@ -48,15 +48,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
|
||||
}
|
||||
|
||||
// A Dialer contains options for connecting to WebSocket server.
|
||||
//
|
||||
// It is safe to call Dialer's methods concurrently.
|
||||
type Dialer struct {
|
||||
// NetDial specifies the dial function for creating TCP connections. If
|
||||
// NetDial is nil, net.Dial is used.
|
||||
NetDial func(network, addr string) (net.Conn, error)
|
||||
|
||||
// NetDialContext specifies the dial function for creating TCP connections. If
|
||||
// NetDialContext is nil, net.DialContext is used.
|
||||
// NetDialContext is nil, NetDial is used.
|
||||
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
|
||||
// NetDialTLSContext is nil, NetDialContext is used.
|
||||
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
|
||||
// TLSClientConfig is ignored.
|
||||
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// Proxy specifies a function to return a proxy for a given
|
||||
// Request. If the function returns a non-nil error, the
|
||||
// request is aborted with the provided error.
|
||||
@@ -65,6 +73,8 @@ type Dialer struct {
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
|
||||
// If nil, the default configuration is used.
|
||||
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
|
||||
// is done there and TLSClientConfig is ignored.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// HandshakeTimeout specifies the duration for the handshake to complete.
|
||||
@@ -176,7 +186,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
Method: http.MethodGet,
|
||||
URL: u,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
@@ -237,13 +247,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||
// Get network dial function.
|
||||
var netDial func(network, add string) (net.Conn, error)
|
||||
|
||||
if d.NetDialContext != nil {
|
||||
netDial = func(network, addr string) (net.Conn, error) {
|
||||
return d.NetDialContext(ctx, network, addr)
|
||||
switch u.Scheme {
|
||||
case "http":
|
||||
if d.NetDialContext != nil {
|
||||
netDial = func(network, addr string) (net.Conn, error) {
|
||||
return d.NetDialContext(ctx, network, addr)
|
||||
}
|
||||
} else if d.NetDial != nil {
|
||||
netDial = d.NetDial
|
||||
}
|
||||
} else if d.NetDial != nil {
|
||||
netDial = d.NetDial
|
||||
} else {
|
||||
case "https":
|
||||
if d.NetDialTLSContext != nil {
|
||||
netDial = func(network, addr string) (net.Conn, error) {
|
||||
return d.NetDialTLSContext(ctx, network, addr)
|
||||
}
|
||||
} else if d.NetDialContext != nil {
|
||||
netDial = func(network, addr string) (net.Conn, error) {
|
||||
return d.NetDialContext(ctx, network, addr)
|
||||
}
|
||||
} else if d.NetDial != nil {
|
||||
netDial = d.NetDial
|
||||
}
|
||||
default:
|
||||
return nil, nil, errMalformedURL
|
||||
}
|
||||
|
||||
if netDial == nil {
|
||||
netDialer := &net.Dialer{}
|
||||
netDial = func(network, addr string) (net.Conn, error) {
|
||||
return netDialer.DialContext(ctx, network, addr)
|
||||
@@ -304,7 +333,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||
}
|
||||
}()
|
||||
|
||||
if u.Scheme == "https" {
|
||||
if u.Scheme == "https" && d.NetDialTLSContext == nil {
|
||||
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
|
||||
|
||||
cfg := cloneTLSConfig(d.TLSClientConfig)
|
||||
if cfg.ServerName == "" {
|
||||
cfg.ServerName = hostNoPort
|
||||
@@ -312,11 +343,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||
tlsConn := tls.Client(netConn, cfg)
|
||||
netConn = tlsConn
|
||||
|
||||
var err error
|
||||
if trace != nil {
|
||||
err = doHandshakeWithTrace(trace, tlsConn, cfg)
|
||||
} else {
|
||||
err = doHandshake(tlsConn, cfg)
|
||||
if trace != nil && trace.TLSHandshakeStart != nil {
|
||||
trace.TLSHandshakeStart()
|
||||
}
|
||||
err := doHandshake(ctx, tlsConn, cfg)
|
||||
if trace != nil && trace.TLSHandshakeDone != nil {
|
||||
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -348,8 +380,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||
}
|
||||
|
||||
if resp.StatusCode != 101 ||
|
||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
||||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
|
||||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
|
||||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
|
||||
// Before closing the network connection on return from this
|
||||
// function, slurp up some of the response to aid application
|
||||
@@ -382,14 +414,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||
return conn, resp, nil
|
||||
}
|
||||
|
||||
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return err
|
||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return &tls.Config{}
|
||||
}
|
||||
if !cfg.InsecureSkipVerify {
|
||||
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return cfg.Clone()
|
||||
}
|
||||
|
16
vendor/github.com/gorilla/websocket/client_clone.go
generated
vendored
16
vendor/github.com/gorilla/websocket/client_clone.go
generated
vendored
@@ -1,16 +0,0 @@
|
||||
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package websocket
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return &tls.Config{}
|
||||
}
|
||||
return cfg.Clone()
|
||||
}
|
38
vendor/github.com/gorilla/websocket/client_clone_legacy.go
generated
vendored
38
vendor/github.com/gorilla/websocket/client_clone_legacy.go
generated
vendored
@@ -1,38 +0,0 @@
|
||||
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.8
|
||||
|
||||
package websocket
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
// cloneTLSConfig clones all public fields except the fields
|
||||
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
|
||||
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
|
||||
// config in active use.
|
||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return &tls.Config{}
|
||||
}
|
||||
return &tls.Config{
|
||||
Rand: cfg.Rand,
|
||||
Time: cfg.Time,
|
||||
Certificates: cfg.Certificates,
|
||||
NameToCertificate: cfg.NameToCertificate,
|
||||
GetCertificate: cfg.GetCertificate,
|
||||
RootCAs: cfg.RootCAs,
|
||||
NextProtos: cfg.NextProtos,
|
||||
ServerName: cfg.ServerName,
|
||||
ClientAuth: cfg.ClientAuth,
|
||||
ClientCAs: cfg.ClientCAs,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
CipherSuites: cfg.CipherSuites,
|
||||
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||
ClientSessionCache: cfg.ClientSessionCache,
|
||||
MinVersion: cfg.MinVersion,
|
||||
MaxVersion: cfg.MaxVersion,
|
||||
CurvePreferences: cfg.CurvePreferences,
|
||||
}
|
||||
}
|
63
vendor/github.com/gorilla/websocket/conn.go
generated
vendored
63
vendor/github.com/gorilla/websocket/conn.go
generated
vendored
@@ -13,6 +13,7 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -401,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) writeBufs(bufs ...[]byte) error {
|
||||
b := net.Buffers(bufs)
|
||||
_, err := b.WriteTo(c.conn)
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteControl writes a control message with the given deadline. The allowed
|
||||
// message types are CloseMessage, PingMessage and PongMessage.
|
||||
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
|
||||
@@ -794,47 +801,69 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||
}
|
||||
|
||||
// 2. Read and parse first two bytes of frame header.
|
||||
// To aid debugging, collect and report all errors in the first two bytes
|
||||
// of the header.
|
||||
|
||||
var errors []string
|
||||
|
||||
p, err := c.read(2)
|
||||
if err != nil {
|
||||
return noFrame, err
|
||||
}
|
||||
|
||||
final := p[0]&finalBit != 0
|
||||
frameType := int(p[0] & 0xf)
|
||||
final := p[0]&finalBit != 0
|
||||
rsv1 := p[0]&rsv1Bit != 0
|
||||
rsv2 := p[0]&rsv2Bit != 0
|
||||
rsv3 := p[0]&rsv3Bit != 0
|
||||
mask := p[1]&maskBit != 0
|
||||
c.setReadRemaining(int64(p[1] & 0x7f))
|
||||
|
||||
c.readDecompress = false
|
||||
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
||||
c.readDecompress = true
|
||||
p[0] &^= rsv1Bit
|
||||
if rsv1 {
|
||||
if c.newDecompressionReader != nil {
|
||||
c.readDecompress = true
|
||||
} else {
|
||||
errors = append(errors, "RSV1 set")
|
||||
}
|
||||
}
|
||||
|
||||
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
|
||||
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
|
||||
if rsv2 {
|
||||
errors = append(errors, "RSV2 set")
|
||||
}
|
||||
|
||||
if rsv3 {
|
||||
errors = append(errors, "RSV3 set")
|
||||
}
|
||||
|
||||
switch frameType {
|
||||
case CloseMessage, PingMessage, PongMessage:
|
||||
if c.readRemaining > maxControlFramePayloadSize {
|
||||
return noFrame, c.handleProtocolError("control frame length > 125")
|
||||
errors = append(errors, "len > 125 for control")
|
||||
}
|
||||
if !final {
|
||||
return noFrame, c.handleProtocolError("control frame not final")
|
||||
errors = append(errors, "FIN not set on control")
|
||||
}
|
||||
case TextMessage, BinaryMessage:
|
||||
if !c.readFinal {
|
||||
return noFrame, c.handleProtocolError("message start before final message frame")
|
||||
errors = append(errors, "data before FIN")
|
||||
}
|
||||
c.readFinal = final
|
||||
case continuationFrame:
|
||||
if c.readFinal {
|
||||
return noFrame, c.handleProtocolError("continuation after final message frame")
|
||||
errors = append(errors, "continuation after FIN")
|
||||
}
|
||||
c.readFinal = final
|
||||
default:
|
||||
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
|
||||
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
|
||||
}
|
||||
|
||||
if mask != c.isServer {
|
||||
errors = append(errors, "bad MASK")
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
|
||||
}
|
||||
|
||||
// 3. Read and parse frame length as per
|
||||
@@ -872,10 +901,6 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||
|
||||
// 4. Handle frame masking.
|
||||
|
||||
if mask != c.isServer {
|
||||
return noFrame, c.handleProtocolError("incorrect mask flag")
|
||||
}
|
||||
|
||||
if mask {
|
||||
c.readMaskPos = 0
|
||||
p, err := c.read(len(c.readMaskKey))
|
||||
@@ -935,7 +960,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||
if len(payload) >= 2 {
|
||||
closeCode = int(binary.BigEndian.Uint16(payload))
|
||||
if !isValidReceivedCloseCode(closeCode) {
|
||||
return noFrame, c.handleProtocolError("invalid close code")
|
||||
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
|
||||
}
|
||||
closeText = string(payload[2:])
|
||||
if !utf8.ValidString(closeText) {
|
||||
@@ -952,7 +977,11 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||
}
|
||||
|
||||
func (c *Conn) handleProtocolError(message string) error {
|
||||
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
|
||||
data := FormatCloseMessage(CloseProtocolError, message)
|
||||
if len(data) > maxControlFramePayloadSize {
|
||||
data = data[:maxControlFramePayloadSize]
|
||||
}
|
||||
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
|
||||
return errors.New("websocket: " + message)
|
||||
}
|
||||
|
||||
|
15
vendor/github.com/gorilla/websocket/conn_write.go
generated
vendored
15
vendor/github.com/gorilla/websocket/conn_write.go
generated
vendored
@@ -1,15 +0,0 @@
|
||||
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package websocket
|
||||
|
||||
import "net"
|
||||
|
||||
func (c *Conn) writeBufs(bufs ...[]byte) error {
|
||||
b := net.Buffers(bufs)
|
||||
_, err := b.WriteTo(c.conn)
|
||||
return err
|
||||
}
|
18
vendor/github.com/gorilla/websocket/conn_write_legacy.go
generated
vendored
18
vendor/github.com/gorilla/websocket/conn_write_legacy.go
generated
vendored
@@ -1,18 +0,0 @@
|
||||
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !go1.8
|
||||
|
||||
package websocket
|
||||
|
||||
func (c *Conn) writeBufs(bufs ...[]byte) error {
|
||||
for _, buf := range bufs {
|
||||
if len(buf) > 0 {
|
||||
if _, err := c.conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
1
vendor/github.com/gorilla/websocket/mask.go
generated
vendored
1
vendor/github.com/gorilla/websocket/mask.go
generated
vendored
@@ -2,6 +2,7 @@
|
||||
// this source code is governed by a BSD-style license that can be found in the
|
||||
// LICENSE file.
|
||||
|
||||
//go:build !appengine
|
||||
// +build !appengine
|
||||
|
||||
package websocket
|
||||
|
1
vendor/github.com/gorilla/websocket/mask_safe.go
generated
vendored
1
vendor/github.com/gorilla/websocket/mask_safe.go
generated
vendored
@@ -2,6 +2,7 @@
|
||||
// this source code is governed by a BSD-style license that can be found in the
|
||||
// LICENSE file.
|
||||
|
||||
//go:build appengine
|
||||
// +build appengine
|
||||
|
||||
package websocket
|
||||
|
2
vendor/github.com/gorilla/websocket/proxy.go
generated
vendored
2
vendor/github.com/gorilla/websocket/proxy.go
generated
vendored
@@ -48,7 +48,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
connectReq := &http.Request{
|
||||
Method: "CONNECT",
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Opaque: addr},
|
||||
Host: addr,
|
||||
Header: connectHeader,
|
||||
|
8
vendor/github.com/gorilla/websocket/server.go
generated
vendored
8
vendor/github.com/gorilla/websocket/server.go
generated
vendored
@@ -23,6 +23,8 @@ func (e HandshakeError) Error() string { return e.message }
|
||||
|
||||
// Upgrader specifies parameters for upgrading an HTTP connection to a
|
||||
// WebSocket connection.
|
||||
//
|
||||
// It is safe to call Upgrader's methods concurrently.
|
||||
type Upgrader struct {
|
||||
// HandshakeTimeout specifies the duration for the handshake to complete.
|
||||
HandshakeTimeout time.Duration
|
||||
@@ -115,8 +117,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
|
||||
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
||||
//
|
||||
// The responseHeader is included in the response to the client's upgrade
|
||||
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
|
||||
// application negotiated subprotocol (Sec-WebSocket-Protocol).
|
||||
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
|
||||
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
|
||||
//
|
||||
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
|
||||
// response.
|
||||
@@ -131,7 +133,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
|
||||
}
|
||||
|
||||
if r.Method != "GET" {
|
||||
if r.Method != http.MethodGet {
|
||||
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
|
||||
}
|
||||
|
||||
|
21
vendor/github.com/gorilla/websocket/tls_handshake.go
generated
vendored
Normal file
21
vendor/github.com/gorilla/websocket/tls_handshake.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build go1.17
|
||||
// +build go1.17
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if !cfg.InsecureSkipVerify {
|
||||
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
21
vendor/github.com/gorilla/websocket/tls_handshake_116.go
generated
vendored
Normal file
21
vendor/github.com/gorilla/websocket/tls_handshake_116.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build !go1.17
|
||||
// +build !go1.17
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
if !cfg.InsecureSkipVerify {
|
||||
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
19
vendor/github.com/gorilla/websocket/trace.go
generated
vendored
19
vendor/github.com/gorilla/websocket/trace.go
generated
vendored
@@ -1,19 +0,0 @@
|
||||
// +build go1.8
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http/httptrace"
|
||||
)
|
||||
|
||||
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
|
||||
if trace.TLSHandshakeStart != nil {
|
||||
trace.TLSHandshakeStart()
|
||||
}
|
||||
err := doHandshake(tlsConn, cfg)
|
||||
if trace.TLSHandshakeDone != nil {
|
||||
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
|
||||
}
|
||||
return err
|
||||
}
|
12
vendor/github.com/gorilla/websocket/trace_17.go
generated
vendored
12
vendor/github.com/gorilla/websocket/trace_17.go
generated
vendored
@@ -1,12 +0,0 @@
|
||||
// +build !go1.8
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http/httptrace"
|
||||
)
|
||||
|
||||
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
|
||||
return doHandshake(tlsConn, cfg)
|
||||
}
|
125
vendor/github.com/jinzhu/copier/copier.go
generated
vendored
125
vendor/github.com/jinzhu/copier/copier.go
generated
vendored
@@ -24,6 +24,13 @@ const (
|
||||
|
||||
// Denotes that the value as been copied
|
||||
hasCopied
|
||||
|
||||
// Some default converter types for a nicer syntax
|
||||
String string = ""
|
||||
Bool bool = false
|
||||
Int int = 0
|
||||
Float32 float32 = 0
|
||||
Float64 float64 = 0
|
||||
)
|
||||
|
||||
// Option sets copy options
|
||||
@@ -32,6 +39,18 @@ type Option struct {
|
||||
// struct having all it's fields set to their zero values respectively (see IsZero() in reflect/value.go)
|
||||
IgnoreEmpty bool
|
||||
DeepCopy bool
|
||||
Converters []TypeConverter
|
||||
}
|
||||
|
||||
type TypeConverter struct {
|
||||
SrcType interface{}
|
||||
DstType interface{}
|
||||
Fn func(src interface{}) (interface{}, error)
|
||||
}
|
||||
|
||||
type converterPair struct {
|
||||
SrcType reflect.Type
|
||||
DstType reflect.Type
|
||||
}
|
||||
|
||||
// Tag Flags
|
||||
@@ -59,12 +78,27 @@ func CopyWithOption(toValue interface{}, fromValue interface{}, opt Option) (err
|
||||
|
||||
func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) {
|
||||
var (
|
||||
isSlice bool
|
||||
amount = 1
|
||||
from = indirect(reflect.ValueOf(fromValue))
|
||||
to = indirect(reflect.ValueOf(toValue))
|
||||
isSlice bool
|
||||
amount = 1
|
||||
from = indirect(reflect.ValueOf(fromValue))
|
||||
to = indirect(reflect.ValueOf(toValue))
|
||||
converters map[converterPair]TypeConverter
|
||||
)
|
||||
|
||||
// save convertes into map for faster lookup
|
||||
for i := range opt.Converters {
|
||||
if converters == nil {
|
||||
converters = make(map[converterPair]TypeConverter)
|
||||
}
|
||||
|
||||
pair := converterPair{
|
||||
SrcType: reflect.TypeOf(opt.Converters[i].SrcType),
|
||||
DstType: reflect.TypeOf(opt.Converters[i].DstType),
|
||||
}
|
||||
|
||||
converters[pair] = opt.Converters[i]
|
||||
}
|
||||
|
||||
if !to.CanAddr() {
|
||||
return ErrInvalidCopyDestination
|
||||
}
|
||||
@@ -113,13 +147,16 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
|
||||
|
||||
for _, k := range from.MapKeys() {
|
||||
toKey := indirect(reflect.New(toType.Key()))
|
||||
if !set(toKey, k, opt.DeepCopy) {
|
||||
if !set(toKey, k, opt.DeepCopy, converters) {
|
||||
return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key())
|
||||
}
|
||||
|
||||
elemType, _ := indirectType(toType.Elem())
|
||||
elemType := toType.Elem()
|
||||
if elemType.Kind() != reflect.Slice {
|
||||
elemType, _ = indirectType(elemType)
|
||||
}
|
||||
toValue := indirect(reflect.New(elemType))
|
||||
if !set(toValue, from.MapIndex(k), opt.DeepCopy) {
|
||||
if !set(toValue, from.MapIndex(k), opt.DeepCopy, converters) {
|
||||
if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -148,7 +185,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
|
||||
to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem()))
|
||||
}
|
||||
|
||||
if !set(to.Index(i), from.Index(i), opt.DeepCopy) {
|
||||
if !set(to.Index(i), from.Index(i), opt.DeepCopy, converters) {
|
||||
// ignore error while copy slice element
|
||||
err = copier(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt)
|
||||
if err != nil {
|
||||
@@ -203,6 +240,8 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
|
||||
|
||||
// check source
|
||||
if source.IsValid() {
|
||||
copyUnexportedStructFields(dest, source)
|
||||
|
||||
// Copy from source field to dest field or method
|
||||
fromTypeFields := deepFields(fromType)
|
||||
for _, field := range fromTypeFields {
|
||||
@@ -249,7 +288,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
|
||||
toField := dest.FieldByName(destFieldName)
|
||||
if toField.IsValid() {
|
||||
if toField.CanSet() {
|
||||
if !set(toField, fromField, opt.DeepCopy) {
|
||||
if !set(toField, fromField, opt.DeepCopy, converters) {
|
||||
if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -291,7 +330,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
|
||||
if toField := dest.FieldByName(destFieldName); toField.IsValid() && toField.CanSet() {
|
||||
values := fromMethod.Call([]reflect.Value{})
|
||||
if len(values) >= 1 {
|
||||
set(toField, values[0], opt.DeepCopy)
|
||||
set(toField, values[0], opt.DeepCopy, converters)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -303,7 +342,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
|
||||
if to.Len() < i+1 {
|
||||
to.Set(reflect.Append(to, dest.Addr()))
|
||||
} else {
|
||||
if !set(to.Index(i), dest.Addr(), opt.DeepCopy) {
|
||||
if !set(to.Index(i), dest.Addr(), opt.DeepCopy, converters) {
|
||||
// ignore error while copy slice element
|
||||
err = copier(to.Index(i).Addr().Interface(), dest.Addr().Interface(), opt)
|
||||
if err != nil {
|
||||
@@ -315,7 +354,7 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
|
||||
if to.Len() < i+1 {
|
||||
to.Set(reflect.Append(to, dest))
|
||||
} else {
|
||||
if !set(to.Index(i), dest, opt.DeepCopy) {
|
||||
if !set(to.Index(i), dest, opt.DeepCopy, converters) {
|
||||
// ignore error while copy slice element
|
||||
err = copier(to.Index(i).Addr().Interface(), dest.Interface(), opt)
|
||||
if err != nil {
|
||||
@@ -334,6 +373,24 @@ func copier(toValue interface{}, fromValue interface{}, opt Option) (err error)
|
||||
return
|
||||
}
|
||||
|
||||
func copyUnexportedStructFields(to, from reflect.Value) {
|
||||
if from.Kind() != reflect.Struct || to.Kind() != reflect.Struct || !from.Type().AssignableTo(to.Type()) {
|
||||
return
|
||||
}
|
||||
|
||||
// create a shallow copy of 'to' to get all fields
|
||||
tmp := indirect(reflect.New(to.Type()))
|
||||
tmp.Set(from)
|
||||
|
||||
// revert exported fields
|
||||
for i := 0; i < to.NumField(); i++ {
|
||||
if tmp.Field(i).CanSet() {
|
||||
tmp.Field(i).Set(to.Field(i))
|
||||
}
|
||||
}
|
||||
to.Set(tmp)
|
||||
}
|
||||
|
||||
func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool {
|
||||
if !ignoreEmpty {
|
||||
return false
|
||||
@@ -352,10 +409,10 @@ func deepFields(reflectType reflect.Type) []reflect.StructField {
|
||||
// field name. It is empty for upper case (exported) field names.
|
||||
// See https://golang.org/ref/spec#Uniqueness_of_identifiers
|
||||
if v.PkgPath == "" {
|
||||
fields = append(fields, v)
|
||||
if v.Anonymous {
|
||||
// also consider fields of anonymous fields as fields of the root
|
||||
fields = append(fields, deepFields(v.Type)...)
|
||||
} else {
|
||||
fields = append(fields, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -381,8 +438,14 @@ func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) {
|
||||
return reflectType, isPtr
|
||||
}
|
||||
|
||||
func set(to, from reflect.Value, deepCopy bool) bool {
|
||||
func set(to, from reflect.Value, deepCopy bool, converters map[converterPair]TypeConverter) bool {
|
||||
if from.IsValid() {
|
||||
if ok, err := lookupAndCopyWithConverter(to, from, converters); err != nil {
|
||||
return false
|
||||
} else if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if to.Kind() == reflect.Ptr {
|
||||
// set `to` to nil if from is nil
|
||||
if from.Kind() == reflect.Ptr && from.IsNil() {
|
||||
@@ -416,6 +479,9 @@ func set(to, from reflect.Value, deepCopy bool) bool {
|
||||
toKind = reflect.TypeOf(to.Interface()).Kind()
|
||||
}
|
||||
}
|
||||
if from.Kind() == reflect.Ptr && from.IsNil() {
|
||||
return true
|
||||
}
|
||||
if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice {
|
||||
return false
|
||||
}
|
||||
@@ -457,7 +523,7 @@ func set(to, from reflect.Value, deepCopy bool) bool {
|
||||
to.Set(rv)
|
||||
}
|
||||
} else if from.Kind() == reflect.Ptr {
|
||||
return set(to, from.Elem(), deepCopy)
|
||||
return set(to, from.Elem(), deepCopy, converters)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
@@ -466,6 +532,33 @@ func set(to, from reflect.Value, deepCopy bool) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// lookupAndCopyWithConverter looks up the type pair, on success the TypeConverter Fn func is called to copy src to dst field.
|
||||
func lookupAndCopyWithConverter(to, from reflect.Value, converters map[converterPair]TypeConverter) (copied bool, err error) {
|
||||
pair := converterPair{
|
||||
SrcType: from.Type(),
|
||||
DstType: to.Type(),
|
||||
}
|
||||
|
||||
if cnv, ok := converters[pair]; ok {
|
||||
result, err := cnv.Fn(from.Interface())
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
to.Set(reflect.ValueOf(result))
|
||||
} else {
|
||||
// in case we've got a nil value to copy
|
||||
to.Set(reflect.Zero(to.Type()))
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// parseTags Parses struct tags and returns uint8 bit flags.
|
||||
func parseTags(tag string) (flg uint8, name string, err error) {
|
||||
for _, t := range strings.Split(tag, ",") {
|
||||
|
1
vendor/github.com/rs/xid/README.md
generated
vendored
1
vendor/github.com/rs/xid/README.md
generated
vendored
@@ -69,6 +69,7 @@ References:
|
||||
- Rust port by [Jérôme Renard](https://github.com/jeromer/): https://github.com/jeromer/libxid
|
||||
- Ruby port by [Valar](https://github.com/valarpirai/): https://github.com/valarpirai/ruby_xid
|
||||
- Java port by [0xShamil](https://github.com/0xShamil/): https://github.com/0xShamil/java-xid
|
||||
- Dart port by [Peter Bwire](https://github.com/pitabwire): https://pub.dev/packages/xid
|
||||
|
||||
## Install
|
||||
|
||||
|
11
vendor/github.com/rs/xid/error.go
generated
vendored
Normal file
11
vendor/github.com/rs/xid/error.go
generated
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
package xid
|
||||
|
||||
const (
|
||||
// ErrInvalidID is returned when trying to unmarshal an invalid ID.
|
||||
ErrInvalidID strErr = "xid: invalid ID"
|
||||
)
|
||||
|
||||
// strErr allows declaring errors as constants.
|
||||
type strErr string
|
||||
|
||||
func (err strErr) Error() string { return string(err) }
|
26
vendor/github.com/rs/xid/id.go
generated
vendored
26
vendor/github.com/rs/xid/id.go
generated
vendored
@@ -47,7 +47,6 @@ import (
|
||||
"crypto/rand"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io/ioutil"
|
||||
@@ -73,9 +72,6 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidID is returned when trying to unmarshal an invalid ID
|
||||
ErrInvalidID = errors.New("xid: invalid ID")
|
||||
|
||||
// objectIDCounter is atomically incremented when generating a new ObjectId
|
||||
// using NewObjectId() function. It's used as a counter part of an id.
|
||||
// This id is initialized with a random value.
|
||||
@@ -242,7 +238,9 @@ func (id *ID) UnmarshalText(text []byte) error {
|
||||
return ErrInvalidID
|
||||
}
|
||||
}
|
||||
decode(id, text)
|
||||
if !decode(id, text) {
|
||||
return ErrInvalidID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -253,11 +251,15 @@ func (id *ID) UnmarshalJSON(b []byte) error {
|
||||
*id = nilID
|
||||
return nil
|
||||
}
|
||||
// Check the slice length to prevent panic on passing it to UnmarshalText()
|
||||
if len(b) < 2 {
|
||||
return ErrInvalidID
|
||||
}
|
||||
return id.UnmarshalText(b[1 : len(b)-1])
|
||||
}
|
||||
|
||||
// decode by unrolling the stdlib base32 algorithm + removing all safe checks
|
||||
func decode(id *ID, src []byte) {
|
||||
// decode by unrolling the stdlib base32 algorithm + customized safe check.
|
||||
func decode(id *ID, src []byte) bool {
|
||||
_ = src[19]
|
||||
_ = id[11]
|
||||
|
||||
@@ -273,6 +275,16 @@ func decode(id *ID, src []byte) {
|
||||
id[2] = dec[src[3]]<<4 | dec[src[4]]>>1
|
||||
id[1] = dec[src[1]]<<6 | dec[src[2]]<<1 | dec[src[3]]>>4
|
||||
id[0] = dec[src[0]]<<3 | dec[src[1]]>>2
|
||||
|
||||
// Validate that there are no discarer bits (padding) in src that would
|
||||
// cause the string-encoded id not to equal src.
|
||||
var check [4]byte
|
||||
|
||||
check[3] = encoding[(id[11]<<4)&0x1F]
|
||||
check[2] = encoding[(id[11]>>1)&0x1F]
|
||||
check[1] = encoding[(id[11]>>6)&0x1F|(id[10]<<2)&0x1F]
|
||||
check[0] = encoding[id[10]>>3]
|
||||
return bytes.Equal([]byte(src[16:20]), check[:])
|
||||
}
|
||||
|
||||
// Time returns the timestamp part of the id.
|
||||
|
54
vendor/github.com/stretchr/testify/assert/assertion_compare.go
generated
vendored
54
vendor/github.com/stretchr/testify/assert/assertion_compare.go
generated
vendored
@@ -3,6 +3,7 @@ package assert
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CompareType int
|
||||
@@ -30,6 +31,8 @@ var (
|
||||
float64Type = reflect.TypeOf(float64(1))
|
||||
|
||||
stringType = reflect.TypeOf("")
|
||||
|
||||
timeType = reflect.TypeOf(time.Time{})
|
||||
)
|
||||
|
||||
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
|
||||
@@ -299,6 +302,27 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
// Check for known struct types we can check for compare results.
|
||||
case reflect.Struct:
|
||||
{
|
||||
// All structs enter here. We're not interested in most types.
|
||||
if !canConvert(obj1Value, timeType) {
|
||||
break
|
||||
}
|
||||
|
||||
// time.Time can compared!
|
||||
timeObj1, ok := obj1.(time.Time)
|
||||
if !ok {
|
||||
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
|
||||
}
|
||||
|
||||
timeObj2, ok := obj2.(time.Time)
|
||||
if !ok {
|
||||
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
|
||||
}
|
||||
|
||||
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
|
||||
}
|
||||
}
|
||||
|
||||
return compareEqual, false
|
||||
@@ -310,7 +334,10 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
|
||||
// assert.Greater(t, float64(2), float64(1))
|
||||
// assert.Greater(t, "b", "a")
|
||||
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// GreaterOrEqual asserts that the first element is greater than or equal to the second
|
||||
@@ -320,7 +347,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
|
||||
// assert.GreaterOrEqual(t, "b", "a")
|
||||
// assert.GreaterOrEqual(t, "b", "b")
|
||||
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Less asserts that the first element is less than the second
|
||||
@@ -329,7 +359,10 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
|
||||
// assert.Less(t, float64(1), float64(2))
|
||||
// assert.Less(t, "a", "b")
|
||||
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// LessOrEqual asserts that the first element is less than or equal to the second
|
||||
@@ -339,7 +372,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
|
||||
// assert.LessOrEqual(t, "a", "b")
|
||||
// assert.LessOrEqual(t, "b", "b")
|
||||
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Positive asserts that the specified element is positive
|
||||
@@ -347,8 +383,11 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
|
||||
// assert.Positive(t, 1)
|
||||
// assert.Positive(t, 1.23)
|
||||
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
zero := reflect.Zero(reflect.TypeOf(e))
|
||||
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs)
|
||||
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Negative asserts that the specified element is negative
|
||||
@@ -356,8 +395,11 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||
// assert.Negative(t, -1)
|
||||
// assert.Negative(t, -1.23)
|
||||
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
zero := reflect.Zero(reflect.TypeOf(e))
|
||||
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs)
|
||||
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
|
||||
}
|
||||
|
||||
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
|
||||
|
16
vendor/github.com/stretchr/testify/assert/assertion_compare_can_convert.go
generated
vendored
Normal file
16
vendor/github.com/stretchr/testify/assert/assertion_compare_can_convert.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build go1.17
|
||||
// +build go1.17
|
||||
|
||||
// TODO: once support for Go 1.16 is dropped, this file can be
|
||||
// merged/removed with assertion_compare_go1.17_test.go and
|
||||
// assertion_compare_legacy.go
|
||||
|
||||
package assert
|
||||
|
||||
import "reflect"
|
||||
|
||||
// Wrapper around reflect.Value.CanConvert, for compatability
|
||||
// reasons.
|
||||
func canConvert(value reflect.Value, to reflect.Type) bool {
|
||||
return value.CanConvert(to)
|
||||
}
|
16
vendor/github.com/stretchr/testify/assert/assertion_compare_legacy.go
generated
vendored
Normal file
16
vendor/github.com/stretchr/testify/assert/assertion_compare_legacy.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !go1.17
|
||||
// +build !go1.17
|
||||
|
||||
// TODO: once support for Go 1.16 is dropped, this file can be
|
||||
// merged/removed with assertion_compare_go1.17_test.go and
|
||||
// assertion_compare_can_convert.go
|
||||
|
||||
package assert
|
||||
|
||||
import "reflect"
|
||||
|
||||
// Older versions of Go does not have the reflect.Value.CanConvert
|
||||
// method.
|
||||
func canConvert(value reflect.Value, to reflect.Type) bool {
|
||||
return false
|
||||
}
|
12
vendor/github.com/stretchr/testify/assert/assertion_format.go
generated
vendored
12
vendor/github.com/stretchr/testify/assert/assertion_format.go
generated
vendored
@@ -123,6 +123,18 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
|
||||
return ErrorAs(t, err, target, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
|
||||
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// ErrorIsf asserts that at least one of the errors in err's chain matches target.
|
||||
// This is a wrapper for errors.Is.
|
||||
func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool {
|
||||
|
24
vendor/github.com/stretchr/testify/assert/assertion_forward.go
generated
vendored
24
vendor/github.com/stretchr/testify/assert/assertion_forward.go
generated
vendored
@@ -222,6 +222,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
|
||||
return ErrorAsf(a.t, err, target, msg, args...)
|
||||
}
|
||||
|
||||
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// a.ErrorContains(err, expectedErrorSubString)
|
||||
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := a.t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ErrorContains(a.t, theError, contains, msgAndArgs...)
|
||||
}
|
||||
|
||||
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
|
||||
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) bool {
|
||||
if h, ok := a.t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ErrorContainsf(a.t, theError, contains, msg, args...)
|
||||
}
|
||||
|
||||
// ErrorIs asserts that at least one of the errors in err's chain matches target.
|
||||
// This is a wrapper for errors.Is.
|
||||
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) bool {
|
||||
|
8
vendor/github.com/stretchr/testify/assert/assertion_order.go
generated
vendored
8
vendor/github.com/stretchr/testify/assert/assertion_order.go
generated
vendored
@@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
|
||||
// assert.IsIncreasing(t, []float{1, 2})
|
||||
// assert.IsIncreasing(t, []string{"a", "b"})
|
||||
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
|
||||
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs)
|
||||
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// IsNonIncreasing asserts that the collection is not increasing
|
||||
@@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
|
||||
// assert.IsNonIncreasing(t, []float{2, 1})
|
||||
// assert.IsNonIncreasing(t, []string{"b", "a"})
|
||||
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
|
||||
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs)
|
||||
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// IsDecreasing asserts that the collection is decreasing
|
||||
@@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{})
|
||||
// assert.IsDecreasing(t, []float{2, 1})
|
||||
// assert.IsDecreasing(t, []string{"b", "a"})
|
||||
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
|
||||
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs)
|
||||
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// IsNonDecreasing asserts that the collection is not decreasing
|
||||
@@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo
|
||||
// assert.IsNonDecreasing(t, []float{1, 2})
|
||||
// assert.IsNonDecreasing(t, []string{"a", "b"})
|
||||
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
|
||||
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs)
|
||||
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
112
vendor/github.com/stretchr/testify/assert/assertions.go
generated
vendored
112
vendor/github.com/stretchr/testify/assert/assertions.go
generated
vendored
@@ -718,10 +718,14 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
|
||||
// return (false, false) if impossible.
|
||||
// return (true, false) if element was not found.
|
||||
// return (true, true) if element was found.
|
||||
func includeElement(list interface{}, element interface{}) (ok, found bool) {
|
||||
func containsElement(list interface{}, element interface{}) (ok, found bool) {
|
||||
|
||||
listValue := reflect.ValueOf(list)
|
||||
listKind := reflect.TypeOf(list).Kind()
|
||||
listType := reflect.TypeOf(list)
|
||||
if listType == nil {
|
||||
return false, false
|
||||
}
|
||||
listKind := listType.Kind()
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
ok = false
|
||||
@@ -764,7 +768,7 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
ok, found := includeElement(s, contains)
|
||||
ok, found := containsElement(s, contains)
|
||||
if !ok {
|
||||
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
|
||||
}
|
||||
@@ -787,7 +791,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
ok, found := includeElement(s, contains)
|
||||
ok, found := containsElement(s, contains)
|
||||
if !ok {
|
||||
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
|
||||
}
|
||||
@@ -831,7 +835,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
|
||||
|
||||
for i := 0; i < subsetValue.Len(); i++ {
|
||||
element := subsetValue.Index(i).Interface()
|
||||
ok, found := includeElement(list, element)
|
||||
ok, found := containsElement(list, element)
|
||||
if !ok {
|
||||
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
|
||||
}
|
||||
@@ -852,7 +856,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
|
||||
h.Helper()
|
||||
}
|
||||
if subset == nil {
|
||||
return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...)
|
||||
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
|
||||
}
|
||||
|
||||
subsetValue := reflect.ValueOf(subset)
|
||||
@@ -875,7 +879,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
|
||||
|
||||
for i := 0; i < subsetValue.Len(); i++ {
|
||||
element := subsetValue.Index(i).Interface()
|
||||
ok, found := includeElement(list, element)
|
||||
ok, found := containsElement(list, element)
|
||||
if !ok {
|
||||
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
|
||||
}
|
||||
@@ -1000,27 +1004,21 @@ func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
|
||||
type PanicTestFunc func()
|
||||
|
||||
// didPanic returns true if the function passed to it panics. Otherwise, it returns false.
|
||||
func didPanic(f PanicTestFunc) (bool, interface{}, string) {
|
||||
|
||||
didPanic := false
|
||||
var message interface{}
|
||||
var stack string
|
||||
func() {
|
||||
|
||||
defer func() {
|
||||
if message = recover(); message != nil {
|
||||
didPanic = true
|
||||
stack = string(debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
// call the target function
|
||||
f()
|
||||
func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string) {
|
||||
didPanic = true
|
||||
|
||||
defer func() {
|
||||
message = recover()
|
||||
if didPanic {
|
||||
stack = string(debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
return didPanic, message, stack
|
||||
// call the target function
|
||||
f()
|
||||
didPanic = false
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Panics asserts that the code inside the specified PanicTestFunc panics.
|
||||
@@ -1161,11 +1159,15 @@ func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs
|
||||
bf, bok := toFloat(actual)
|
||||
|
||||
if !aok || !bok {
|
||||
return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...)
|
||||
return Fail(t, "Parameters must be numerical", msgAndArgs...)
|
||||
}
|
||||
|
||||
if math.IsNaN(af) && math.IsNaN(bf) {
|
||||
return true
|
||||
}
|
||||
|
||||
if math.IsNaN(af) {
|
||||
return Fail(t, fmt.Sprintf("Expected must not be NaN"), msgAndArgs...)
|
||||
return Fail(t, "Expected must not be NaN", msgAndArgs...)
|
||||
}
|
||||
|
||||
if math.IsNaN(bf) {
|
||||
@@ -1188,7 +1190,7 @@ func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAn
|
||||
if expected == nil || actual == nil ||
|
||||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
|
||||
reflect.TypeOf(expected).Kind() != reflect.Slice {
|
||||
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
|
||||
return Fail(t, "Parameters must be slice", msgAndArgs...)
|
||||
}
|
||||
|
||||
actualSlice := reflect.ValueOf(actual)
|
||||
@@ -1250,8 +1252,12 @@ func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, m
|
||||
|
||||
func calcRelativeError(expected, actual interface{}) (float64, error) {
|
||||
af, aok := toFloat(expected)
|
||||
if !aok {
|
||||
return 0, fmt.Errorf("expected value %q cannot be converted to float", expected)
|
||||
bf, bok := toFloat(actual)
|
||||
if !aok || !bok {
|
||||
return 0, fmt.Errorf("Parameters must be numerical")
|
||||
}
|
||||
if math.IsNaN(af) && math.IsNaN(bf) {
|
||||
return 0, nil
|
||||
}
|
||||
if math.IsNaN(af) {
|
||||
return 0, errors.New("expected value must not be NaN")
|
||||
@@ -1259,10 +1265,6 @@ func calcRelativeError(expected, actual interface{}) (float64, error) {
|
||||
if af == 0 {
|
||||
return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error")
|
||||
}
|
||||
bf, bok := toFloat(actual)
|
||||
if !bok {
|
||||
return 0, fmt.Errorf("actual value %q cannot be converted to float", actual)
|
||||
}
|
||||
if math.IsNaN(bf) {
|
||||
return 0, errors.New("actual value must not be NaN")
|
||||
}
|
||||
@@ -1298,7 +1300,7 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
|
||||
if expected == nil || actual == nil ||
|
||||
reflect.TypeOf(actual).Kind() != reflect.Slice ||
|
||||
reflect.TypeOf(expected).Kind() != reflect.Slice {
|
||||
return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...)
|
||||
return Fail(t, "Parameters must be slice", msgAndArgs...)
|
||||
}
|
||||
|
||||
actualSlice := reflect.ValueOf(actual)
|
||||
@@ -1375,6 +1377,27 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
|
||||
return true
|
||||
}
|
||||
|
||||
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// assert.ErrorContains(t, err, expectedErrorSubString)
|
||||
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
if !Error(t, theError, msgAndArgs...) {
|
||||
return false
|
||||
}
|
||||
|
||||
actual := theError.Error()
|
||||
if !strings.Contains(actual, contains) {
|
||||
return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// matchRegexp return true if a specified regexp matches a string.
|
||||
func matchRegexp(rx interface{}, str interface{}) bool {
|
||||
|
||||
@@ -1588,12 +1611,17 @@ func diff(expected interface{}, actual interface{}) string {
|
||||
}
|
||||
|
||||
var e, a string
|
||||
if et != reflect.TypeOf("") {
|
||||
e = spewConfig.Sdump(expected)
|
||||
a = spewConfig.Sdump(actual)
|
||||
} else {
|
||||
|
||||
switch et {
|
||||
case reflect.TypeOf(""):
|
||||
e = reflect.ValueOf(expected).String()
|
||||
a = reflect.ValueOf(actual).String()
|
||||
case reflect.TypeOf(time.Time{}):
|
||||
e = spewConfigStringerEnabled.Sdump(expected)
|
||||
a = spewConfigStringerEnabled.Sdump(actual)
|
||||
default:
|
||||
e = spewConfig.Sdump(expected)
|
||||
a = spewConfig.Sdump(actual)
|
||||
}
|
||||
|
||||
diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
|
||||
@@ -1625,6 +1653,14 @@ var spewConfig = spew.ConfigState{
|
||||
MaxDepth: 10,
|
||||
}
|
||||
|
||||
var spewConfigStringerEnabled = spew.ConfigState{
|
||||
Indent: " ",
|
||||
DisablePointerAddresses: true,
|
||||
DisableCapacities: true,
|
||||
SortKeys: true,
|
||||
MaxDepth: 10,
|
||||
}
|
||||
|
||||
type tHelper interface {
|
||||
Helper()
|
||||
}
|
||||
|
30
vendor/github.com/stretchr/testify/require/require.go
generated
vendored
30
vendor/github.com/stretchr/testify/require/require.go
generated
vendored
@@ -280,6 +280,36 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// assert.ErrorContains(t, err, expectedErrorSubString)
|
||||
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
if assert.ErrorContains(t, theError, contains, msgAndArgs...) {
|
||||
return
|
||||
}
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
|
||||
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
if assert.ErrorContainsf(t, theError, contains, msg, args...) {
|
||||
return
|
||||
}
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// ErrorIs asserts that at least one of the errors in err's chain matches target.
|
||||
// This is a wrapper for errors.Is.
|
||||
func ErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) {
|
||||
|
24
vendor/github.com/stretchr/testify/require/require_forward.go
generated
vendored
24
vendor/github.com/stretchr/testify/require/require_forward.go
generated
vendored
@@ -223,6 +223,30 @@ func (a *Assertions) ErrorAsf(err error, target interface{}, msg string, args ..
|
||||
ErrorAsf(a.t, err, target, msg, args...)
|
||||
}
|
||||
|
||||
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// a.ErrorContains(err, expectedErrorSubString)
|
||||
func (a *Assertions) ErrorContains(theError error, contains string, msgAndArgs ...interface{}) {
|
||||
if h, ok := a.t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
ErrorContains(a.t, theError, contains, msgAndArgs...)
|
||||
}
|
||||
|
||||
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// a.ErrorContainsf(err, expectedErrorSubString, "error message %s", "formatted")
|
||||
func (a *Assertions) ErrorContainsf(theError error, contains string, msg string, args ...interface{}) {
|
||||
if h, ok := a.t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
ErrorContainsf(a.t, theError, contains, msg, args...)
|
||||
}
|
||||
|
||||
// ErrorIs asserts that at least one of the errors in err's chain matches target.
|
||||
// This is a wrapper for errors.Is.
|
||||
func (a *Assertions) ErrorIs(err error, target error, msgAndArgs ...interface{}) {
|
||||
|
2
vendor/go.etcd.io/bbolt/.gitignore
generated
vendored
2
vendor/go.etcd.io/bbolt/.gitignore
generated
vendored
@@ -3,5 +3,3 @@
|
||||
*.swp
|
||||
/bin/
|
||||
cover.out
|
||||
/.idea
|
||||
*.iml
|
||||
|
3
vendor/go.etcd.io/bbolt/.travis.yml
generated
vendored
3
vendor/go.etcd.io/bbolt/.travis.yml
generated
vendored
@@ -4,10 +4,9 @@ go_import_path: go.etcd.io/bbolt
|
||||
sudo: false
|
||||
|
||||
go:
|
||||
- 1.15
|
||||
- 1.12
|
||||
|
||||
before_install:
|
||||
- go get -v golang.org/x/sys/unix
|
||||
- go get -v honnef.co/go/tools/...
|
||||
- go get -v github.com/kisielk/errcheck
|
||||
|
||||
|
2
vendor/go.etcd.io/bbolt/Makefile
generated
vendored
2
vendor/go.etcd.io/bbolt/Makefile
generated
vendored
@@ -2,6 +2,8 @@ BRANCH=`git rev-parse --abbrev-ref HEAD`
|
||||
COMMIT=`git rev-parse --short HEAD`
|
||||
GOLDFLAGS="-X main.branch $(BRANCH) -X main.commit $(COMMIT)"
|
||||
|
||||
default: build
|
||||
|
||||
race:
|
||||
@TEST_FREELIST_TYPE=hashmap go test -v -race -test.run="TestSimulate_(100op|1000op)"
|
||||
@echo "array freelist test"
|
||||
|
5
vendor/go.etcd.io/bbolt/README.md
generated
vendored
5
vendor/go.etcd.io/bbolt/README.md
generated
vendored
@@ -908,14 +908,12 @@ Below is a list of public, open source projects that use Bolt:
|
||||
* [BoltStore](https://github.com/yosssi/boltstore) - Session store using Bolt.
|
||||
* [Boltdb Boilerplate](https://github.com/bobintornado/boltdb-boilerplate) - Boilerplate wrapper around bolt aiming to make simple calls one-liners.
|
||||
* [BoltDbWeb](https://github.com/evnix/boltdbweb) - A web based GUI for BoltDB files.
|
||||
* [BoltDB Viewer](https://github.com/zc310/rich_boltdb) - A BoltDB Viewer Can run on Windows、Linux、Android system.
|
||||
* [bleve](http://www.blevesearch.com/) - A pure Go search engine similar to ElasticSearch that uses Bolt as the default storage backend.
|
||||
* [btcwallet](https://github.com/btcsuite/btcwallet) - A bitcoin wallet.
|
||||
* [buckets](https://github.com/joyrexus/buckets) - a bolt wrapper streamlining
|
||||
simple tx and key scans.
|
||||
* [cayley](https://github.com/google/cayley) - Cayley is an open-source graph database using Bolt as optional backend.
|
||||
* [ChainStore](https://github.com/pressly/chainstore) - Simple key-value interface to a variety of storage engines organized as a chain of operations.
|
||||
* [🌰 Chestnut](https://github.com/jrapoport/chestnut) - Chestnut is encrypted storage for Go.
|
||||
* [Consul](https://github.com/hashicorp/consul) - Consul is service discovery and configuration made easy. Distributed, highly available, and datacenter-aware.
|
||||
* [DVID](https://github.com/janelia-flyem/dvid) - Added Bolt as optional storage engine and testing it against Basho-tuned leveldb.
|
||||
* [dcrwallet](https://github.com/decred/dcrwallet) - A wallet for the Decred cryptocurrency.
|
||||
@@ -940,8 +938,9 @@ Below is a list of public, open source projects that use Bolt:
|
||||
* [MetricBase](https://github.com/msiebuhr/MetricBase) - Single-binary version of Graphite.
|
||||
* [MuLiFS](https://github.com/dankomiocevic/mulifs) - Music Library Filesystem creates a filesystem to organise your music files.
|
||||
* [NATS](https://github.com/nats-io/nats-streaming-server) - NATS Streaming uses bbolt for message and metadata storage.
|
||||
* [Operation Go: A Routine Mission](http://gocode.io) - An online programming game for Golang using Bolt for user accounts and a leaderboard.
|
||||
* [photosite/session](https://godoc.org/bitbucket.org/kardianos/photosite/session) - Sessions for a photo viewing site.
|
||||
* [Prometheus Annotation Server](https://github.com/oliver006/prom_annotation_server) - Annotation server for PromDash & Prometheus service monitoring system.
|
||||
* [Rain](https://github.com/cenkalti/rain) - BitTorrent client and library.
|
||||
* [reef-pi](https://github.com/reef-pi/reef-pi) - reef-pi is an award winning, modular, DIY reef tank controller using easy to learn electronics based on a Raspberry Pi.
|
||||
* [Request Baskets](https://github.com/darklynx/request-baskets) - A web service to collect arbitrary HTTP requests and inspect them via REST API or simple web UI, similar to [RequestBin](http://requestb.in/) service
|
||||
* [Seaweed File System](https://github.com/chrislusf/seaweedfs) - Highly scalable distributed key~file system with O(1) disk read.
|
||||
|
17
vendor/go.etcd.io/bbolt/bolt_unix.go
generated
vendored
17
vendor/go.etcd.io/bbolt/bolt_unix.go
generated
vendored
@@ -7,8 +7,6 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// flock acquires an advisory lock on a file descriptor.
|
||||
@@ -51,13 +49,13 @@ func funlock(db *DB) error {
|
||||
// mmap memory maps a DB's data file.
|
||||
func mmap(db *DB, sz int) error {
|
||||
// Map the data file to memory.
|
||||
b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
|
||||
b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Advise the kernel that the mmap is accessed randomly.
|
||||
err = unix.Madvise(b, syscall.MADV_RANDOM)
|
||||
err = madvise(b, syscall.MADV_RANDOM)
|
||||
if err != nil && err != syscall.ENOSYS {
|
||||
// Ignore not implemented error in kernel because it still works.
|
||||
return fmt.Errorf("madvise: %s", err)
|
||||
@@ -78,9 +76,18 @@ func munmap(db *DB) error {
|
||||
}
|
||||
|
||||
// Unmap using the original byte slice.
|
||||
err := unix.Munmap(db.dataref)
|
||||
err := syscall.Munmap(db.dataref)
|
||||
db.dataref = nil
|
||||
db.data = nil
|
||||
db.datasz = 0
|
||||
return err
|
||||
}
|
||||
|
||||
// NOTE: This function is copied from stdlib because it is not available on darwin.
|
||||
func madvise(b []byte, advice int) (err error) {
|
||||
_, _, e1 := syscall.Syscall(syscall.SYS_MADVISE, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), uintptr(advice))
|
||||
if e1 != 0 {
|
||||
err = e1
|
||||
}
|
||||
return
|
||||
}
|
||||
|
114
vendor/go.etcd.io/bbolt/compact.go
generated
vendored
114
vendor/go.etcd.io/bbolt/compact.go
generated
vendored
@@ -1,114 +0,0 @@
|
||||
package bbolt
|
||||
|
||||
// Compact will create a copy of the source DB and in the destination DB. This may
|
||||
// reclaim space that the source database no longer has use for. txMaxSize can be
|
||||
// used to limit the transactions size of this process and may trigger intermittent
|
||||
// commits. A value of zero will ignore transaction sizes.
|
||||
// TODO: merge with: https://github.com/etcd-io/etcd/blob/b7f0f52a16dbf83f18ca1d803f7892d750366a94/mvcc/backend/backend.go#L349
|
||||
func Compact(dst, src *DB, txMaxSize int64) error {
|
||||
// commit regularly, or we'll run out of memory for large datasets if using one transaction.
|
||||
var size int64
|
||||
tx, err := dst.Begin(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if err := walk(src, func(keys [][]byte, k, v []byte, seq uint64) error {
|
||||
// On each key/value, check if we have exceeded tx size.
|
||||
sz := int64(len(k) + len(v))
|
||||
if size+sz > txMaxSize && txMaxSize != 0 {
|
||||
// Commit previous transaction.
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start new transaction.
|
||||
tx, err = dst.Begin(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
size = 0
|
||||
}
|
||||
size += sz
|
||||
|
||||
// Create bucket on the root transaction if this is the first level.
|
||||
nk := len(keys)
|
||||
if nk == 0 {
|
||||
bkt, err := tx.CreateBucket(k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bkt.SetSequence(seq); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create buckets on subsequent levels, if necessary.
|
||||
b := tx.Bucket(keys[0])
|
||||
if nk > 1 {
|
||||
for _, k := range keys[1:] {
|
||||
b = b.Bucket(k)
|
||||
}
|
||||
}
|
||||
|
||||
// Fill the entire page for best compaction.
|
||||
b.FillPercent = 1.0
|
||||
|
||||
// If there is no value then this is a bucket call.
|
||||
if v == nil {
|
||||
bkt, err := b.CreateBucket(k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bkt.SetSequence(seq); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise treat it as a key/value pair.
|
||||
return b.Put(k, v)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// walkFunc is the type of the function called for keys (buckets and "normal"
|
||||
// values) discovered by Walk. keys is the list of keys to descend to the bucket
|
||||
// owning the discovered key/value pair k/v.
|
||||
type walkFunc func(keys [][]byte, k, v []byte, seq uint64) error
|
||||
|
||||
// walk walks recursively the bolt database db, calling walkFn for each key it finds.
|
||||
func walk(db *DB, walkFn walkFunc) error {
|
||||
return db.View(func(tx *Tx) error {
|
||||
return tx.ForEach(func(name []byte, b *Bucket) error {
|
||||
return walkBucket(b, nil, name, nil, b.Sequence(), walkFn)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func walkBucket(b *Bucket, keypath [][]byte, k, v []byte, seq uint64, fn walkFunc) error {
|
||||
// Execute callback.
|
||||
if err := fn(keypath, k, v, seq); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If this is not a bucket then stop.
|
||||
if v != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Iterate over each child key/value.
|
||||
keypath = append(keypath, k)
|
||||
return b.ForEach(func(k, v []byte) error {
|
||||
if v == nil {
|
||||
bkt := b.Bucket(k)
|
||||
return walkBucket(bkt, keypath, k, nil, bkt.Sequence(), fn)
|
||||
}
|
||||
return walkBucket(b, keypath, k, v, b.Sequence(), fn)
|
||||
})
|
||||
}
|
66
vendor/go.etcd.io/bbolt/db.go
generated
vendored
66
vendor/go.etcd.io/bbolt/db.go
generated
vendored
@@ -120,12 +120,6 @@ type DB struct {
|
||||
// of truncate() and fsync() when growing the data file.
|
||||
AllocSize int
|
||||
|
||||
// Mlock locks database file in memory when set to true.
|
||||
// It prevents major page faults, however used memory can't be reclaimed.
|
||||
//
|
||||
// Supported only on Unix via mlock/munlock syscalls.
|
||||
Mlock bool
|
||||
|
||||
path string
|
||||
openFile func(string, int, os.FileMode) (*os.File, error)
|
||||
file *os.File
|
||||
@@ -194,7 +188,6 @@ func Open(path string, mode os.FileMode, options *Options) (*DB, error) {
|
||||
db.MmapFlags = options.MmapFlags
|
||||
db.NoFreelistSync = options.NoFreelistSync
|
||||
db.FreelistType = options.FreelistType
|
||||
db.Mlock = options.Mlock
|
||||
|
||||
// Set default values for later DB operations.
|
||||
db.MaxBatchSize = DefaultMaxBatchSize
|
||||
@@ -344,8 +337,7 @@ func (db *DB) mmap(minsz int) error {
|
||||
}
|
||||
|
||||
// Ensure the size is at least the minimum size.
|
||||
fileSize := int(info.Size())
|
||||
var size = fileSize
|
||||
var size = int(info.Size())
|
||||
if size < minsz {
|
||||
size = minsz
|
||||
}
|
||||
@@ -354,13 +346,6 @@ func (db *DB) mmap(minsz int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if db.Mlock {
|
||||
// Unlock db memory
|
||||
if err := db.munlock(fileSize); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Dereference all mmap references before unmapping.
|
||||
if db.rwtx != nil {
|
||||
db.rwtx.root.dereference()
|
||||
@@ -376,13 +361,6 @@ func (db *DB) mmap(minsz int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if db.Mlock {
|
||||
// Don't allow swapping of data file
|
||||
if err := db.mlock(fileSize); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Save references to the meta pages.
|
||||
db.meta0 = db.page(0).meta()
|
||||
db.meta1 = db.page(1).meta()
|
||||
@@ -444,36 +422,12 @@ func (db *DB) mmapSize(size int) (int, error) {
|
||||
return int(sz), nil
|
||||
}
|
||||
|
||||
func (db *DB) munlock(fileSize int) error {
|
||||
if err := munlock(db, fileSize); err != nil {
|
||||
return fmt.Errorf("munlock error: " + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) mlock(fileSize int) error {
|
||||
if err := mlock(db, fileSize); err != nil {
|
||||
return fmt.Errorf("mlock error: " + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) mrelock(fileSizeFrom, fileSizeTo int) error {
|
||||
if err := db.munlock(fileSizeFrom); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := db.mlock(fileSizeTo); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// init creates a new database file and initializes its meta pages.
|
||||
func (db *DB) init() error {
|
||||
// Create two meta pages on a buffer.
|
||||
buf := make([]byte, db.pageSize*4)
|
||||
for i := 0; i < 2; i++ {
|
||||
p := db.pageInBuffer(buf, pgid(i))
|
||||
p := db.pageInBuffer(buf[:], pgid(i))
|
||||
p.id = pgid(i)
|
||||
p.flags = metaPageFlag
|
||||
|
||||
@@ -490,13 +444,13 @@ func (db *DB) init() error {
|
||||
}
|
||||
|
||||
// Write an empty freelist at page 3.
|
||||
p := db.pageInBuffer(buf, pgid(2))
|
||||
p := db.pageInBuffer(buf[:], pgid(2))
|
||||
p.id = pgid(2)
|
||||
p.flags = freelistPageFlag
|
||||
p.count = 0
|
||||
|
||||
// Write an empty leaf page at page 4.
|
||||
p = db.pageInBuffer(buf, pgid(3))
|
||||
p = db.pageInBuffer(buf[:], pgid(3))
|
||||
p.id = pgid(3)
|
||||
p.flags = leafPageFlag
|
||||
p.count = 0
|
||||
@@ -508,7 +462,6 @@ func (db *DB) init() error {
|
||||
if err := fdatasync(db); err != nil {
|
||||
return err
|
||||
}
|
||||
db.filesz = len(buf)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1020,12 +973,6 @@ func (db *DB) grow(sz int) error {
|
||||
if err := db.file.Sync(); err != nil {
|
||||
return fmt.Errorf("file sync error: %s", err)
|
||||
}
|
||||
if db.Mlock {
|
||||
// unlock old file and lock new one
|
||||
if err := db.mrelock(db.filesz, sz); err != nil {
|
||||
return fmt.Errorf("mlock/munlock error: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.filesz = sz
|
||||
@@ -1117,11 +1064,6 @@ type Options struct {
|
||||
// OpenFile is used to open files. It defaults to os.OpenFile. This option
|
||||
// is useful for writing hermetic tests.
|
||||
OpenFile func(string, int, os.FileMode) (*os.File, error)
|
||||
|
||||
// Mlock locks database file in memory when set to true.
|
||||
// It prevents potential page faults, however
|
||||
// used memory can't be reclaimed. (UNIX only)
|
||||
Mlock bool
|
||||
}
|
||||
|
||||
// DefaultOptions represent the options used if nil options are passed into Open().
|
||||
|
6
vendor/go.etcd.io/bbolt/freelist_hmap.go
generated
vendored
6
vendor/go.etcd.io/bbolt/freelist_hmap.go
generated
vendored
@@ -4,7 +4,7 @@ import "sort"
|
||||
|
||||
// hashmapFreeCount returns count of free pages(hashmap version)
|
||||
func (f *freelist) hashmapFreeCount() int {
|
||||
// use the forwardMap to get the total count
|
||||
// use the forwardmap to get the total count
|
||||
count := 0
|
||||
for _, size := range f.forwardMap {
|
||||
count += int(size)
|
||||
@@ -41,7 +41,7 @@ func (f *freelist) hashmapAllocate(txid txid, n int) pgid {
|
||||
|
||||
for pid := range bm {
|
||||
// remove the initial
|
||||
f.delSpan(pid, size)
|
||||
f.delSpan(pid, uint64(size))
|
||||
|
||||
f.allocs[pid] = txid
|
||||
|
||||
@@ -51,7 +51,7 @@ func (f *freelist) hashmapAllocate(txid txid, n int) pgid {
|
||||
f.addSpan(pid+pgid(n), remain)
|
||||
|
||||
for i := pgid(0); i < pgid(n); i++ {
|
||||
delete(f.cache, pid+i)
|
||||
delete(f.cache, pid+pgid(i))
|
||||
}
|
||||
return pid
|
||||
}
|
||||
|
36
vendor/go.etcd.io/bbolt/mlock_unix.go
generated
vendored
36
vendor/go.etcd.io/bbolt/mlock_unix.go
generated
vendored
@@ -1,36 +0,0 @@
|
||||
// +build !windows
|
||||
|
||||
package bbolt
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
// mlock locks memory of db file
|
||||
func mlock(db *DB, fileSize int) error {
|
||||
sizeToLock := fileSize
|
||||
if sizeToLock > db.datasz {
|
||||
// Can't lock more than mmaped slice
|
||||
sizeToLock = db.datasz
|
||||
}
|
||||
if err := unix.Mlock(db.dataref[:sizeToLock]); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//munlock unlocks memory of db file
|
||||
func munlock(db *DB, fileSize int) error {
|
||||
if db.dataref == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sizeToUnlock := fileSize
|
||||
if sizeToUnlock > db.datasz {
|
||||
// Can't unlock more than mmaped slice
|
||||
sizeToUnlock = db.datasz
|
||||
}
|
||||
|
||||
if err := unix.Munlock(db.dataref[:sizeToUnlock]); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
11
vendor/go.etcd.io/bbolt/mlock_windows.go
generated
vendored
11
vendor/go.etcd.io/bbolt/mlock_windows.go
generated
vendored
@@ -1,11 +0,0 @@
|
||||
package bbolt
|
||||
|
||||
// mlock locks memory of db file
|
||||
func mlock(_ *DB, _ int) error {
|
||||
panic("mlock is supported only on UNIX systems")
|
||||
}
|
||||
|
||||
//munlock unlocks memory of db file
|
||||
func munlock(_ *DB, _ int) error {
|
||||
panic("munlock is supported only on UNIX systems")
|
||||
}
|
3
vendor/go.etcd.io/bbolt/tx.go
generated
vendored
3
vendor/go.etcd.io/bbolt/tx.go
generated
vendored
@@ -188,6 +188,7 @@ func (tx *Tx) Commit() error {
|
||||
}
|
||||
|
||||
// If strict mode is enabled then perform a consistency check.
|
||||
// Only the first consistency error is reported in the panic.
|
||||
if tx.db.StrictMode {
|
||||
ch := tx.Check()
|
||||
var errs []string
|
||||
@@ -392,7 +393,7 @@ func (tx *Tx) CopyFile(path string, mode os.FileMode) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.WriteTo(f)
|
||||
err = tx.Copy(f)
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
return err
|
||||
|
10
vendor/modules.txt
vendored
10
vendor/modules.txt
vendored
@@ -12,10 +12,10 @@ github.com/asdine/storm/v3/q
|
||||
# github.com/davecgh/go-spew v1.1.1
|
||||
## explicit
|
||||
github.com/davecgh/go-spew/spew
|
||||
# github.com/gorilla/websocket v1.4.2
|
||||
# github.com/gorilla/websocket v1.5.0
|
||||
## explicit; go 1.12
|
||||
github.com/gorilla/websocket
|
||||
# github.com/jinzhu/copier v0.3.4
|
||||
# github.com/jinzhu/copier v0.3.5
|
||||
## explicit; go 1.13
|
||||
github.com/jinzhu/copier
|
||||
# github.com/logrusorgru/aurora v2.0.3+incompatible
|
||||
@@ -24,14 +24,14 @@ github.com/logrusorgru/aurora
|
||||
# github.com/pmezard/go-difflib v1.0.0
|
||||
## explicit
|
||||
github.com/pmezard/go-difflib/difflib
|
||||
# github.com/rs/xid v1.3.0
|
||||
# github.com/rs/xid v1.4.0
|
||||
## explicit; go 1.12
|
||||
github.com/rs/xid
|
||||
# github.com/stretchr/testify v1.7.0
|
||||
# github.com/stretchr/testify v1.7.1
|
||||
## explicit; go 1.13
|
||||
github.com/stretchr/testify/assert
|
||||
github.com/stretchr/testify/require
|
||||
# go.etcd.io/bbolt v1.3.6
|
||||
# go.etcd.io/bbolt v1.3.5
|
||||
## explicit; go 1.12
|
||||
go.etcd.io/bbolt
|
||||
# golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d
|
||||
|
Reference in New Issue
Block a user