mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-12-24 04:23:54 +08:00
Compare commits
91 Commits
v2.0.1
...
use-contex
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef5dcf68d0 | ||
|
|
6704cf7227 | ||
|
|
9233e6fd39 | ||
|
|
1ca65d9631 | ||
|
|
33229da885 | ||
|
|
c274d5fd08 | ||
|
|
10e82f41d6 | ||
|
|
e6c07b2b78 | ||
|
|
eed3ef9606 | ||
|
|
1ec880844d | ||
|
|
4b49652a8c | ||
|
|
d46e7b5bcf | ||
|
|
17fb7dadbc | ||
|
|
ed7fd836e1 | ||
|
|
605bb93c75 | ||
|
|
c73ace2ea0 | ||
|
|
aac6d699da | ||
|
|
7bd7bd5087 | ||
|
|
655bf9fdb1 | ||
|
|
b188055c7d | ||
|
|
aaf1d9d4c6 | ||
|
|
44ce819318 | ||
|
|
e4c76cc60c | ||
|
|
da79faa972 | ||
|
|
46babc89c8 | ||
|
|
9b7a943888 | ||
|
|
a909d30923 | ||
|
|
0851b09e4d | ||
|
|
a302c9dd88 | ||
|
|
1e8f922102 | ||
|
|
4c16e5593f | ||
|
|
49cada4fbc | ||
|
|
ef34510c0b | ||
|
|
e5716caad1 | ||
|
|
4b039cb35c | ||
|
|
aac245441a | ||
|
|
bb54cc68e6 | ||
|
|
7ba1352a60 | ||
|
|
ca849131eb | ||
|
|
ba7e534122 | ||
|
|
db760c34a5 | ||
|
|
ae3ee81bb4 | ||
|
|
c2ca02d149 | ||
|
|
77a64d9c87 | ||
|
|
8dec9cc962 | ||
|
|
f90e52328d | ||
|
|
50aae47618 | ||
|
|
0d79f2d63b | ||
|
|
300152413c | ||
|
|
0de1d731db | ||
|
|
80746abc52 | ||
|
|
a73cf4ca0e | ||
|
|
bc549ee7ed | ||
|
|
c464b46713 | ||
|
|
05ce56008c | ||
|
|
8254cb0cbc | ||
|
|
4ae58b79e3 | ||
|
|
b895d688e0 | ||
|
|
a600cd4ead | ||
|
|
cdb44990cf | ||
|
|
2d9c128111 | ||
|
|
a0d5bdb39f | ||
|
|
4ebcef3cb6 | ||
|
|
fb8d4720d7 | ||
|
|
4080c89127 | ||
|
|
1b67e6f3f6 | ||
|
|
1adb02e087 | ||
|
|
4d4140aa99 | ||
|
|
e31840a37d | ||
|
|
7d2e16f2d3 | ||
|
|
92cd935a16 | ||
|
|
25ce27ce2d | ||
|
|
527d084a4b | ||
|
|
bb9f937bb0 | ||
|
|
511fe88684 | ||
|
|
75504ff201 | ||
|
|
a556feb325 | ||
|
|
d06f47f4b9 | ||
|
|
8d4cc091b4 | ||
|
|
d8f28cb843 | ||
|
|
88861c219d | ||
|
|
7ba6cf28d9 | ||
|
|
c174cfdc6b | ||
|
|
4f198a99dd | ||
|
|
2a9c9fcc40 | ||
|
|
835a85c8bf | ||
|
|
fe5d9ffa61 | ||
|
|
aac186dcc1 | ||
|
|
42931f332f | ||
|
|
8a04648c09 | ||
|
|
854c033fb6 |
42
.github/workflows/build.yml
vendored
42
.github/workflows/build.yml
vendored
@@ -7,37 +7,39 @@ jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
|
||||
|
||||
- name: Test
|
||||
run: go test -race ./... && echo true
|
||||
|
||||
|
||||
coverage:
|
||||
name: Test with Coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
if: success()
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.19.x
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- name: Calc coverage
|
||||
run: |
|
||||
go test -v -covermode=count -coverprofile=coverage.out ./...
|
||||
- name: Convert coverage.out to coverage.lcov
|
||||
uses: jandelgado/gcov2lcov-action@v1.0.6
|
||||
- name: Coveralls
|
||||
uses: coverallsapp/github-action@v1.1.2
|
||||
with:
|
||||
github-token: ${{ secrets.github_token }}
|
||||
path-to-lcov: coverage.lcov
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '1.19'
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
go mod download
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
go test -race -covermode atomic -coverprofile=covprofile ./...
|
||||
- name: Install goveralls
|
||||
run: go install github.com/mattn/goveralls@latest
|
||||
- name: Send coverage
|
||||
env:
|
||||
COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: goveralls -coverprofile=covprofile -service=github
|
||||
|
||||
84
README.md
84
README.md
@@ -2,9 +2,9 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||
[](https://coveralls.io/github/mochi-co/mqtt?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-co/mqtt)
|
||||
[](https://pkg.go.dev/github.com/mochi-co/mqtt)
|
||||
[](https://coveralls.io/github/mochi-co/mqtt?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
|
||||
[](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
|
||||
[](https://github.com/mochi-co/mqtt/issues)
|
||||
|
||||
</p>
|
||||
@@ -19,6 +19,11 @@ MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It
|
||||
## What's new in Version 2.0.0?
|
||||
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
|
||||
|
||||
Don't forget to use the new v2 import paths:
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/v2"
|
||||
```
|
||||
|
||||
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
|
||||
- User and MQTTv5 Packet Properties
|
||||
- Topic Aliases
|
||||
@@ -36,10 +41,10 @@ Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learn
|
||||
- Direct Packet Injection using special inline client, or masquerade as existing clients.
|
||||
- Performant and Stable:
|
||||
- Our classic trie-based Topic-Subscription model.
|
||||
- A new fixed 'FanPool' worker queues to ensure consistent resource allocation and throughput reliability.
|
||||
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
|
||||
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
|
||||
- Over a thousand carefully considered unit test scenarios.
|
||||
- TCP, Websocket, (including SSL/TLS) and $SYS Dashboard listeners.
|
||||
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
|
||||
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
|
||||
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
|
||||
|
||||
@@ -78,22 +83,26 @@ docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
|
||||
Importing Mochi MQTT as a package requires just a few lines of code to get started.
|
||||
``` go
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create the new MQTT Server.
|
||||
server := mqtt.New(nil)
|
||||
|
||||
|
||||
// Allow all connections.
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
// Create a TCP listener on a standard port.
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.Serve()
|
||||
if err != nil {
|
||||
@@ -107,10 +116,15 @@ Examples of running the broker with various configurations can be found in the [
|
||||
#### Network Listeners
|
||||
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
|
||||
|
||||
- `listeners.NewTCP(...)` - A TCP listener.
|
||||
- `listeners.NewWebsocket(...)` A Websocket listener.
|
||||
- `listeners.NewHTTPStats(...)` An HTTP $SYS info dashboard.
|
||||
- Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
|
||||
| Listener | Usage |
|
||||
| --- | --- |
|
||||
| listeners.NewTCP | A TCP listener |
|
||||
| listeners.NewUnixSock | A Unix Socket listener |
|
||||
| listeners.NewNet | A net.Listener listener |
|
||||
| listeners.NewWebsocket | A Websocket listener |
|
||||
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
|
||||
|
||||
> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
|
||||
|
||||
A `*listeners.Config` may be passed to configure TLS.
|
||||
|
||||
@@ -122,6 +136,8 @@ A number of configurable options are available which can be used to alter the be
|
||||
```go
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
Capabilities: mqtt.Capabilities{
|
||||
ClientNetWriteBufferSize: 4096,
|
||||
ClientNetReadBufferSize: 4096,
|
||||
MaximumSessionExpiryInterval: 3600,
|
||||
Compatibilities: mqtt.Compatibilities{
|
||||
ObscureNotAuthorized: true,
|
||||
@@ -131,7 +147,7 @@ server := mqtt.New(&mqtt.Options{
|
||||
})
|
||||
```
|
||||
|
||||
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options.
|
||||
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
|
||||
|
||||
|
||||
## Event Hooks
|
||||
@@ -283,15 +299,16 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
|
||||
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
|
||||
| OnPublish | Called when a client publishes a message. Allows packet modification. |
|
||||
| OnPublished | Called when a client has published a message to subscribers. |
|
||||
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
|
||||
| OnRetainMessage | Called then a published message is retained. |
|
||||
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
|
||||
| OnQosComplete | Called when the Qos flow for a message has been completed. |
|
||||
| OnQosDropped | Called when an inflight message expires before completion. |
|
||||
| OnPacketIDExhausted | Called when a client runs out of unused packet ids to assign. |
|
||||
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
|
||||
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
|
||||
| OnClientExpired | Called when a client session has expired and should be deleted. |
|
||||
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
|
||||
| OnExpireInflights | Called when the server issues a clear request for expired inflight messages.|
|
||||
| StoredClients | Returns clients, eg. from a persistent store. |
|
||||
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
|
||||
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
|
||||
@@ -300,13 +317,22 @@ The function signatures for all the hooks and `mqtt.Hook` interface can be found
|
||||
|
||||
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
|
||||
|
||||
### Packet Injection
|
||||
It's also possible to inject custom MQTT packets directly into the runtime as though they had been received by a specific client. This special client is called an InlineClient, and it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
|
||||
|
||||
Packet injection can be used with MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirement).
|
||||
### Direct Publish
|
||||
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
|
||||
|
||||
```go
|
||||
cl := mqtt.NewInlineClient("inline", "local")
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
```
|
||||
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
|
||||
|
||||
### Packet Injection
|
||||
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
|
||||
|
||||
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
|
||||
|
||||
```go
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
@@ -320,6 +346,8 @@ server.InjectPacket(cl, packets.Packet{
|
||||
|
||||
See the [hooks example](examples/hooks/main.go) to see this feature in action.
|
||||
|
||||
|
||||
|
||||
### Testing
|
||||
#### Unit Tests
|
||||
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
|
||||
@@ -344,23 +372,23 @@ Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inov
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.0.0 | 139,860 | 135,960 | 132,059 | 217,499 | 211,027 | 204,555 |
|
||||
| Mochi v2.2.0 | 127,216 | 125,748 | 124,279 | 319,250 | 309,327 | 299,405 |
|
||||
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
|
||||
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17649 |
|
||||
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.0.0 | 55,189 | 34,840 | 21,298 | 56,980 | 28,557 | 23,781 |
|
||||
| Mochi v2.2.0 | 45,615 | 30,129 | 21,138 | 232,717 | 86,323 | 50,402 |
|
||||
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
|
||||
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3756 |
|
||||
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
|
||||
|
||||
Million Message Challenge (hit the server with 1 million messages immediately):
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.0.0 | 13,573 | 3,678 | 1,848 | 34,309 | 2,470 | 5,636 |
|
||||
| Mochi v2.2.0 | 51,044 | 4,682 | 2,345 | 72,634 | 7,645 | 2,464 |
|
||||
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
|
||||
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
|
||||
|
||||
|
||||
203
clients.go
203
clients.go
@@ -1,11 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -86,7 +88,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
||||
defer cl.RUnlock()
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
for _, client := range cl.internal {
|
||||
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
|
||||
if client.Net.Listener == id && !client.Closed() {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
}
|
||||
@@ -105,7 +107,7 @@ type Client struct {
|
||||
|
||||
// ClientConnection contains the connection transport and metadata for the client.
|
||||
type ClientConnection struct {
|
||||
conn net.Conn // the net.Conn used to establish the connection
|
||||
Conn net.Conn // the net.Conn used to establish the connection
|
||||
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
|
||||
Remote string // the remote address of the client
|
||||
Listener string // listener id of the client
|
||||
@@ -134,30 +136,31 @@ type Will struct {
|
||||
|
||||
// State tracks the state of the client.
|
||||
type ClientState struct {
|
||||
TopicAliases TopicAliases // a map of topic aliases
|
||||
stopCause atomic.Value // reason for stopping
|
||||
Inflight *Inflight // a map of in-flight qos messages
|
||||
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
|
||||
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
|
||||
endOnce sync.Once // only end once
|
||||
packetID uint32 // the current highest packetID
|
||||
done uint32 // atomic counter which indicates that the client has closed
|
||||
keepalive uint16 // the number of seconds the connection can wait
|
||||
TopicAliases TopicAliases // a map of topic aliases
|
||||
stopCause atomic.Value // reason for stopping
|
||||
Inflight *Inflight // a map of in-flight qos messages
|
||||
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
|
||||
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
|
||||
outbound chan *packets.Packet // queue for pending outbound packets
|
||||
endOnce sync.Once // only end once
|
||||
isTakenOver uint32 // used to identify orphaned clients
|
||||
packetID uint32 // the current highest packetID
|
||||
open context.Context // indicate that the client is open for packet exchange
|
||||
outboundQty int32 // number of messages currently in the outbound queue
|
||||
keepalive uint16 // the number of seconds the connection can wait
|
||||
}
|
||||
|
||||
// NewClient returns a new instance of Client.
|
||||
func NewClient(c net.Conn, o *ops) *Client {
|
||||
// newClient returns a new instance of Client. This is almost exclusively used by Server
|
||||
// for creating new clients, but it lives here because it's not dependent.
|
||||
func newClient(c net.Conn, o *ops) *Client {
|
||||
cl := &Client{
|
||||
Net: ClientConnection{
|
||||
conn: c,
|
||||
bconn: bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
},
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(o.capabilities.TopicAliasMaximum),
|
||||
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
|
||||
keepalive: defaultKeepalive,
|
||||
open: context.Background(),
|
||||
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
@@ -165,43 +168,29 @@ func NewClient(c net.Conn, o *ops) *Client {
|
||||
ops: o,
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
cl.Net = ClientConnection{
|
||||
Conn: c,
|
||||
bconn: bufio.NewReadWriter(
|
||||
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
|
||||
bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize),
|
||||
),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// NewInlineClient returns a client used when publishing from the embedding system.
|
||||
func NewInlineClient(id, remote string) *Client {
|
||||
return &Client{
|
||||
ID: id,
|
||||
Net: ClientConnection{
|
||||
Remote: remote,
|
||||
Inline: true,
|
||||
},
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(0),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newClientStub returns an instance of Client with minimal initializations, such as
|
||||
// restoring client data from a db. In particular, the client is marked as offline (done).
|
||||
func newClientStub() *Client {
|
||||
return &Client{
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(0),
|
||||
done: 1,
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
|
||||
func (cl *Client) WriteLoop() {
|
||||
for pk := range cl.State.outbound {
|
||||
if err := cl.WritePacket(*pk); err != nil {
|
||||
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
|
||||
}
|
||||
atomic.AddInt32(&cl.State.outboundQty, -1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,8 +203,8 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Properties.Clean = pk.Connect.Clean
|
||||
cl.Properties.Props = pk.Properties.Copy(false)
|
||||
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
|
||||
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
|
||||
|
||||
@@ -225,7 +214,7 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Properties.Props.AssignedClientID = cl.ID
|
||||
}
|
||||
|
||||
cl.State.keepalive = cl.ops.capabilities.ServerKeepAlive
|
||||
cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
|
||||
if pk.Connect.Keepalive > 0 {
|
||||
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
|
||||
}
|
||||
@@ -253,13 +242,13 @@ func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
|
||||
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||
func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
if cl.Net.conn != nil {
|
||||
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
|
||||
}
|
||||
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
|
||||
}
|
||||
|
||||
_ = cl.Net.conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,28 +256,30 @@ func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
// If no unused packet ids are available, an error is returned and the client
|
||||
// should be disconnected.
|
||||
func (cl *Client) NextPacketID() (i uint32, err error) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
|
||||
i = atomic.LoadUint32(&cl.State.packetID)
|
||||
started := i + 1
|
||||
started := i
|
||||
overflowed := false
|
||||
for {
|
||||
if i >= 65535 {
|
||||
overflowed = true
|
||||
i = 1
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
|
||||
if overflowed && i == started {
|
||||
return 0, packets.ErrQuotaExceeded
|
||||
}
|
||||
|
||||
if i >= cl.ops.options.Capabilities.maximumPacketID {
|
||||
overflowed = true
|
||||
i = 0
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
|
||||
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
|
||||
break
|
||||
atomic.StoreUint32(&cl.State.packetID, i)
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&cl.State.packetID, i)
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
|
||||
@@ -302,7 +293,7 @@ func (cl *Client) ResendInflightMessages(force bool) error {
|
||||
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
|
||||
}
|
||||
|
||||
// cl.ops.hooks.OnQosPublish(cl, tk.Packet, nt, tk.Resends)
|
||||
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
|
||||
err := cl.WritePacket(tk)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -320,17 +311,18 @@ func (cl *Client) ResendInflightMessages(force bool) error {
|
||||
}
|
||||
|
||||
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
|
||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) int64 {
|
||||
var deleted int64
|
||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
|
||||
deleted := []uint16{}
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
deleted++
|
||||
deleted = append(deleted, tk.PacketID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return deleted
|
||||
}
|
||||
|
||||
@@ -340,7 +332,7 @@ func (cl *Client) Read(packetHandler ReadFn) error {
|
||||
var err error
|
||||
|
||||
for {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
if cl.Closed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -365,20 +357,28 @@ func (cl *Client) Read(packetHandler ReadFn) error {
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop(err error) {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
cl.State.endOnce.Do(func() {
|
||||
if cl.Net.conn != nil {
|
||||
_ = cl.Net.conn.Close() // omit close error
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.Close() // omit close error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cl.State.stopCause.Store(err)
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&cl.State.done, 1)
|
||||
if cl.State.outbound != nil {
|
||||
close(cl.State.outbound)
|
||||
}
|
||||
|
||||
if cl.State.open != nil {
|
||||
var cancel context.CancelFunc
|
||||
cl.State.open, cancel = context.WithCancel(cl.State.open)
|
||||
cancel()
|
||||
}
|
||||
|
||||
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
|
||||
})
|
||||
}
|
||||
@@ -391,6 +391,11 @@ func (cl *Client) StopCause() error {
|
||||
return cl.State.stopCause.Load().(error)
|
||||
}
|
||||
|
||||
// Closed returns true if client connection is closed.
|
||||
func (cl *Client) Closed() bool {
|
||||
return cl.State.open == nil || cl.State.open.Err() != nil
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
if cl.Net.bconn == nil {
|
||||
@@ -413,6 +418,10 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
|
||||
return nil
|
||||
}
|
||||
@@ -480,15 +489,6 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if cl.Net.conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer cl.refreshDeadline(cl.State.keepalive)
|
||||
if pk.Expiry > 0 {
|
||||
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
|
||||
}
|
||||
@@ -502,8 +502,8 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
|
||||
}
|
||||
|
||||
if cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // NB we need to know which properties we can encode
|
||||
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
|
||||
}
|
||||
|
||||
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
|
||||
@@ -552,8 +552,19 @@ func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
|
||||
}
|
||||
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
|
||||
if cl.Closed() {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if cl.Net.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
nb := net.Buffers{buf.Bytes()}
|
||||
n, err := nb.WriteTo(cl.Net.conn)
|
||||
n, err := nb.WriteTo(cl.Net.Conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
196
clients_test.go
196
clients_test.go
@@ -1,9 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
@@ -21,16 +23,20 @@ const pkInfo = "packet type %v, %s"
|
||||
|
||||
var errClientStop = errors.New("test stop")
|
||||
|
||||
func newClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
r, w = net.Pipe()
|
||||
|
||||
cl = NewClient(w, &ops{
|
||||
cl = newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: &logger,
|
||||
capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
TopicAliasMaximum: 10000,
|
||||
options: &Options{
|
||||
Capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
TopicAliasMaximum: 10000,
|
||||
MaximumClientWritesPending: 3,
|
||||
maximumPacketID: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -41,6 +47,9 @@ func newClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
cl.State.Inflight.receiveQuota = 10
|
||||
cl.Properties.Props.TopicAliasMaximum = 0
|
||||
cl.Properties.Props.RequestResponseInfo = 0x1
|
||||
|
||||
go cl.WriteLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -106,8 +115,8 @@ func TestClientsDelete(t *testing.T) {
|
||||
|
||||
func TestClientsGetByListener(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
|
||||
cl.Add(&Client{ID: "t1", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", State: ClientState{open: context.Background()}, Net: ClientConnection{Listener: "ws1"}})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
@@ -118,34 +127,23 @@ func TestClientsGetByListener(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Nil(t, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestNewClientStub(t *testing.T) {
|
||||
cl := newClientStub()
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.State.done))
|
||||
}
|
||||
|
||||
func TestNewInlineClient(t *testing.T) {
|
||||
cl := NewInlineClient("inline", "local")
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.Equal(t, uint32(0), atomic.LoadUint32(&cl.State.done))
|
||||
require.Equal(t, "inline", cl.ID)
|
||||
require.Equal(t, "local", cl.Net.Remote)
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.Conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.NotNil(t, cl.ops)
|
||||
require.NotNil(t, cl.ops.options.Capabilities)
|
||||
require.False(t, cl.Net.Inline)
|
||||
}
|
||||
|
||||
func TestClientParseConnect(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
@@ -175,14 +173,14 @@ func TestClientParseConnect(t *testing.T) {
|
||||
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
|
||||
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
|
||||
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
|
||||
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
|
||||
}
|
||||
|
||||
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
@@ -207,13 +205,13 @@ func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientParseConnectNoID(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ParseConnect("tcp1", packets.Packet{})
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientNextPacketID(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
@@ -225,7 +223,7 @@ func TestClientNextPacketID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
// skip over 2
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
@@ -248,33 +246,37 @@ func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDExhausted(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
for i := 0; i <= 65535; i++ {
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
|
||||
}
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
require.Equal(t, uint32(0), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
|
||||
}
|
||||
|
||||
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
|
||||
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
|
||||
|
||||
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
|
||||
_, err = cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
|
||||
cl.State.packetID = uint32(65534)
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(65535), i)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
}
|
||||
|
||||
func TestClientClearInflights(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
|
||||
@@ -284,13 +286,15 @@ func TestClientClearInflights(t *testing.T) {
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
|
||||
cl.ClearInflights(n, 4)
|
||||
deleted := cl.ClearInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
|
||||
require.Equal(t, 2, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessages(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
|
||||
cl, r, w := newClient()
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
@@ -310,7 +314,7 @@ func TestClientResendInflightMessages(t *testing.T) {
|
||||
|
||||
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
r.Close()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
@@ -322,19 +326,19 @@ func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientRefreshDeadline(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.refreshDeadline(10)
|
||||
require.NotNil(t, cl.Net.conn) // how do we check net.Conn deadline?
|
||||
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeader(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
@@ -349,7 +353,7 @@ func TestClientReadFixedHeader(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
@@ -362,8 +366,24 @@ func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
cl.ops.options.Capabilities.MaximumPacketSize = 2
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
@@ -377,7 +397,7 @@ func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
@@ -391,7 +411,7 @@ func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadOK(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -445,9 +465,9 @@ func TestClientReadOK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadDone(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
cl.State.done = 1
|
||||
cl.State.open = nil
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
@@ -460,15 +480,23 @@ func TestClientReadDone(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientStop(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(nil)
|
||||
require.Equal(t, nil, cl.State.stopCause.Load())
|
||||
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
|
||||
require.Equal(t, uint32(1), cl.State.done)
|
||||
require.True(t, cl.Closed())
|
||||
require.Equal(t, nil, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestClientClosed(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
require.False(t, cl.Closed())
|
||||
cl.Stop(nil)
|
||||
require.True(t, cl.Closed())
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -485,7 +513,7 @@ func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -505,7 +533,7 @@ func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadReadPacketOK(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -537,7 +565,7 @@ func TestClientReadReadPacketOK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacket(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
for _, tx := range pkTable {
|
||||
@@ -570,9 +598,17 @@ func TestClientReadPacket(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadPacketInvalidTypeError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Conn.Close()
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid packet type")
|
||||
}
|
||||
|
||||
func TestClientWritePacket(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
|
||||
@@ -588,7 +624,7 @@ func TestClientWritePacket(t *testing.T) {
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
cl.Net.conn.Close()
|
||||
cl.Net.Conn.Close()
|
||||
|
||||
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
@@ -612,7 +648,7 @@ func TestClientWritePacket(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWriteClientOversizePacket(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.MaximumPacketSize = 2
|
||||
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
|
||||
err := cl.WritePacket(pk)
|
||||
@@ -621,7 +657,7 @@ func TestWriteClientOversizePacket(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadingError(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -641,7 +677,7 @@ func TestClientReadPacketReadingError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
cl, r, _ := newClient()
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
@@ -660,7 +696,7 @@ func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(errClientStop)
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
@@ -669,15 +705,15 @@ func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl.Net.conn.Close()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Conn.Close()
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketInvalidPacket(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.WritePacket(packets.Packet{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
52
examples/benchmark/main.go
Normal file
52
examples/benchmark/main.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
|
||||
flag.Parse()
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -35,7 +36,7 @@ func main() {
|
||||
}
|
||||
|
||||
err = server.AddHook(new(debug.Hook), &debug.Options{
|
||||
ShowPacketData: true,
|
||||
// ShowPacketData: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -51,15 +52,30 @@ func main() {
|
||||
// `server.Publish` method. Subscribe to `direct/publish` using your
|
||||
// MQTT client to see the messages.
|
||||
go func() {
|
||||
cl := mqtt.NewInlineClient("inline", "local")
|
||||
for range time.Tick(time.Second * 10) {
|
||||
server.InjectPacket(cl, packets.Packet{
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
for range time.Tick(time.Second * 1) {
|
||||
err := server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "direct/publish",
|
||||
Payload: []byte("scheduled message"),
|
||||
Payload: []byte("injected scheduled message"),
|
||||
})
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.InjectPacket")
|
||||
}
|
||||
server.Log.Info().Msgf("main.go injected packet to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
// There is also a shorthand convenience function, Publish, for easily sending
|
||||
// publish packets if you are not concerned with creating your own packets.
|
||||
go func() {
|
||||
for range time.Tick(time.Second * 5) {
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.Publish")
|
||||
}
|
||||
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -28,7 +29,6 @@ func main() {
|
||||
server.Options.Capabilities.ServerKeepAlive = 60
|
||||
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
||||
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
|
||||
server.Options.Capabilities.Compatibilities.AlwaysReturnResponseInfo = true
|
||||
|
||||
_ = server.AddHook(new(pahoAuthHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -29,12 +30,15 @@ func main() {
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
err := server.AddHook(new(bolt.Hook), bolt.Options{
|
||||
err := server.AddHook(new(bolt.Hook), &bolt.Options{
|
||||
Path: "bolt.db",
|
||||
Options: &bbolt.Options{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
|
||||
100
fanpool.go
100
fanpool.go
@@ -1,100 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co, chowyu08, muXxer
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
xh "github.com/cespare/xxhash/v2"
|
||||
)
|
||||
|
||||
// taskChan is a channel for incoming task functions.
|
||||
type taskChan chan func()
|
||||
|
||||
// FanPool is a fixed-sized fan-style worker pool with multiple
|
||||
// working 'columns'. Instead of a single queue channel processed by
|
||||
// many goroutines, this fan pool uses many queue channels each
|
||||
// processed by a single goroutine.
|
||||
// Very special thanks are given to the authors of HMQ in particular
|
||||
// @chowyu08 and @muXxer for their work on the fixpool worker pool
|
||||
// https://github.com/fhmq/hmq/blob/master/pool/fixpool.go
|
||||
// from which this fan-pool is heavily inspired.
|
||||
type FanPool struct {
|
||||
queue []taskChan
|
||||
wg sync.WaitGroup
|
||||
capacity uint64
|
||||
perChan uint64
|
||||
Mutex sync.Mutex
|
||||
}
|
||||
|
||||
// New returns a new instance of FanPool. fanSize controls the number of 'columns'
|
||||
// of the fan, whereas queueSize controls the size of each column's queue.
|
||||
func NewFanPool(fanSize, queueSize uint64) *FanPool {
|
||||
pool := &FanPool{
|
||||
capacity: fanSize,
|
||||
perChan: queueSize,
|
||||
queue: make([]taskChan, fanSize),
|
||||
}
|
||||
|
||||
pool.fillWorkers(fanSize)
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// fillWorkers adds columns to the fan pool with an associated worker goroutine.
|
||||
func (p *FanPool) fillWorkers(n uint64) {
|
||||
for i := uint64(0); i < n; i++ {
|
||||
p.queue[i] = make(taskChan, p.perChan)
|
||||
go p.worker(p.queue[i])
|
||||
p.wg.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// worker is a worker goroutine which processes tasks from a single queue.
|
||||
func (p *FanPool) worker(ch taskChan) {
|
||||
defer p.wg.Done()
|
||||
var task func()
|
||||
var ok bool
|
||||
for {
|
||||
task, ok = <-ch
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
task()
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue adds a new task to the queue to be processed.
|
||||
func (p *FanPool) Enqueue(id string, task func()) {
|
||||
if p.Size() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// We can use xh.Sum64 to get a specific queue index
|
||||
// which remains the same for a client id, giving each
|
||||
// client their own queue.
|
||||
p.queue[xh.Sum64([]byte(id))%p.Size()] <- task
|
||||
}
|
||||
|
||||
// Wait blocks until all the workers in the pool have completed.
|
||||
func (p *FanPool) Wait() {
|
||||
p.wg.Wait()
|
||||
}
|
||||
|
||||
// Close issues a shutdown signal to the workers.
|
||||
func (p *FanPool) Close() {
|
||||
for i := 0; i < int(p.Size()); i++ {
|
||||
if p.queue[i] != nil {
|
||||
close(p.queue[i])
|
||||
}
|
||||
}
|
||||
p.queue = nil
|
||||
atomic.StoreUint64(&p.capacity, 0)
|
||||
}
|
||||
|
||||
// Size returns the current number of workers in the pool.
|
||||
func (p *FanPool) Size() uint64 {
|
||||
return atomic.LoadUint64(&p.capacity)
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFanPool(t *testing.T) {
|
||||
f := NewFanPool(1, 2)
|
||||
require.NotNil(t, f)
|
||||
require.Equal(t, uint64(1), f.capacity)
|
||||
require.Equal(t, 2, cap(f.queue[0]))
|
||||
|
||||
o := make(chan bool)
|
||||
go func() {
|
||||
f.Enqueue("test", func() {
|
||||
o <- true
|
||||
})
|
||||
}()
|
||||
|
||||
require.True(t, <-o)
|
||||
f.Close()
|
||||
f.Wait()
|
||||
}
|
||||
|
||||
func TestFillWorkers(t *testing.T) {
|
||||
f := &FanPool{
|
||||
perChan: 3,
|
||||
queue: make([]taskChan, 2),
|
||||
}
|
||||
f.fillWorkers(2)
|
||||
require.Len(t, f.queue, 2)
|
||||
require.Equal(t, 3, cap(f.queue[0]))
|
||||
}
|
||||
|
||||
func TestEnqueue(t *testing.T) {
|
||||
f := &FanPool{
|
||||
capacity: 2,
|
||||
queue: []taskChan{
|
||||
make(taskChan, 2),
|
||||
make(taskChan, 2),
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
f.Enqueue("a", func() {})
|
||||
}()
|
||||
require.NotNil(t, <-f.queue[1])
|
||||
}
|
||||
|
||||
func TestEnqueueOnEmpty(t *testing.T) {
|
||||
f := &FanPool{
|
||||
queue: []taskChan{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
f.Enqueue("a", func() {})
|
||||
}()
|
||||
|
||||
require.Len(t, f.queue, 0)
|
||||
}
|
||||
|
||||
func TestSize(t *testing.T) {
|
||||
f := &FanPool{
|
||||
capacity: 10,
|
||||
}
|
||||
|
||||
require.Equal(t, uint64(10), f.Size())
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
f := &FanPool{
|
||||
capacity: 3,
|
||||
queue: []taskChan{
|
||||
make(taskChan, 2),
|
||||
make(taskChan, 2),
|
||||
make(taskChan, 2),
|
||||
},
|
||||
}
|
||||
|
||||
f.Close()
|
||||
require.Equal(t, uint64(0), f.Size())
|
||||
require.Nil(t, f.queue)
|
||||
}
|
||||
6
go.mod
6
go.mod
@@ -6,7 +6,6 @@ require (
|
||||
github.com/alicebob/miniredis/v2 v2.23.0
|
||||
github.com/asdine/storm v2.1.2+incompatible
|
||||
github.com/asdine/storm/v3 v3.2.1
|
||||
github.com/cespare/xxhash/v2 v2.1.2
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.3.5
|
||||
@@ -21,6 +20,7 @@ require (
|
||||
require (
|
||||
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgraph-io/badger v1.6.0 // indirect
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
|
||||
@@ -33,8 +33,8 @@ require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
|
||||
golang.org/x/net v0.0.0-20220927171203-f486391704dc // indirect
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
|
||||
golang.org/x/net v0.7.0 // indirect
|
||||
golang.org/x/sys v0.5.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/protobuf v1.28.1 // indirect
|
||||
)
|
||||
|
||||
10
go.sum
10
go.sum
@@ -109,8 +109,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20220927171203-f486391704dc h1:FxpXZdoBqT8RjqTy6i1E8nXHhW21wK7ptQ/EPIGxzPQ=
|
||||
golang.org/x/net v0.0.0-20220927171203-f486391704dc/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
|
||||
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -118,11 +118,11 @@ golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
|
||||
152
hooks.go
152
hooks.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
// SPDX-FileContributor: mochi-co, thedevop
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -38,15 +39,16 @@ const (
|
||||
OnUnsubscribed
|
||||
OnPublish
|
||||
OnPublished
|
||||
OnPublishDropped
|
||||
OnRetainMessage
|
||||
OnQosPublish
|
||||
OnQosComplete
|
||||
OnQosDropped
|
||||
OnPacketIDExhausted
|
||||
OnWill
|
||||
OnWillSent
|
||||
OnClientExpired
|
||||
OnRetainedExpired
|
||||
OnExpireInflights
|
||||
StoredClients
|
||||
StoredSubscriptions
|
||||
StoredInflightMessages
|
||||
@@ -87,15 +89,16 @@ type Hook interface {
|
||||
OnUnsubscribed(cl *Client, pk packets.Packet)
|
||||
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
OnPublished(cl *Client, pk packets.Packet)
|
||||
OnPublishDropped(cl *Client, pk packets.Packet)
|
||||
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
|
||||
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
|
||||
OnQosComplete(cl *Client, pk packets.Packet)
|
||||
OnQosDropped(cl *Client, pk packets.Packet)
|
||||
OnPacketIDExhausted(cl *Client, pk packets.Packet)
|
||||
OnWill(cl *Client, will Will) (Will, error)
|
||||
OnWillSent(cl *Client, pk packets.Packet)
|
||||
OnClientExpired(cl *Client)
|
||||
OnRetainedExpired(filter string)
|
||||
OnExpireInflights(cl *Client, expiry int64)
|
||||
StoredClients() ([]storage.Client, error)
|
||||
StoredSubscriptions() ([]storage.Subscription, error)
|
||||
StoredInflightMessages() ([]storage.Message, error)
|
||||
@@ -111,10 +114,10 @@ type HookOptions struct {
|
||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||
type Hooks struct {
|
||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||
internal []Hook // a slice of hooks
|
||||
internal atomic.Value // a slice of []Hook
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex
|
||||
sync.Mutex // a mutex for locking when adding hooks
|
||||
}
|
||||
|
||||
// Len returns the number of hooks added.
|
||||
@@ -124,7 +127,7 @@ func (h *Hooks) Len() int64 {
|
||||
|
||||
// Provides returns true if any one hook provides any of the requested hook methods.
|
||||
func (h *Hooks) Provides(b ...byte) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
for _, hb := range b {
|
||||
if hook.Provides(hb) {
|
||||
return true
|
||||
@@ -139,26 +142,39 @@ func (h *Hooks) Provides(b ...byte) bool {
|
||||
func (h *Hooks) Add(hook Hook, config any) error {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
if h.internal == nil {
|
||||
h.internal = []Hook{}
|
||||
}
|
||||
|
||||
err := hook.Init(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
|
||||
}
|
||||
|
||||
h.internal = append(h.internal, hook)
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
i = []Hook{}
|
||||
}
|
||||
|
||||
i = append(i, hook)
|
||||
h.internal.Store(i)
|
||||
atomic.AddInt64(&h.qty, 1)
|
||||
h.wg.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll returns a slice of all the hooks.
|
||||
func (h *Hooks) GetAll() []Hook {
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
return []Hook{}
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// Stop indicates all attached hooks to gracefully end.
|
||||
func (h *Hooks) Stop() {
|
||||
go func() {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
|
||||
if err := hook.Stop(); err != nil {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
|
||||
@@ -173,7 +189,7 @@ func (h *Hooks) Stop() {
|
||||
|
||||
// OnSysInfoTick is called when the $SYS topic values are published out.
|
||||
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSysInfoTick) {
|
||||
hook.OnSysInfoTick(sys)
|
||||
}
|
||||
@@ -182,7 +198,7 @@ func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
|
||||
// OnStarted is called when the server has successfully started.
|
||||
func (h *Hooks) OnStarted() {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStarted) {
|
||||
hook.OnStarted()
|
||||
}
|
||||
@@ -191,7 +207,7 @@ func (h *Hooks) OnStarted() {
|
||||
|
||||
// OnStopped is called when the server has successfully stopped.
|
||||
func (h *Hooks) OnStopped() {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStopped) {
|
||||
hook.OnStopped()
|
||||
}
|
||||
@@ -200,7 +216,7 @@ func (h *Hooks) OnStopped() {
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnect) {
|
||||
hook.OnConnect(cl, pk)
|
||||
}
|
||||
@@ -209,7 +225,7 @@ func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSessionEstablished) {
|
||||
hook.OnSessionEstablished(cl, pk)
|
||||
}
|
||||
@@ -218,7 +234,7 @@ func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnDisconnect) {
|
||||
hook.OnDisconnect(cl, err, expire)
|
||||
}
|
||||
@@ -228,7 +244,7 @@ func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
// OnPacketRead is called when a packet is received from a client.
|
||||
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketRead) {
|
||||
npk, err := hook.OnPacketRead(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
@@ -249,7 +265,7 @@ func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet,
|
||||
// to create their own auth packet handling mechanisms.
|
||||
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnAuthPacket) {
|
||||
npk, err := hook.OnAuthPacket(cl, pkx)
|
||||
if err != nil {
|
||||
@@ -265,7 +281,7 @@ func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet,
|
||||
|
||||
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
|
||||
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketEncode) {
|
||||
pk = hook.OnPacketEncode(cl, pk)
|
||||
}
|
||||
@@ -276,7 +292,7 @@ func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
|
||||
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketProcessed) {
|
||||
hook.OnPacketProcessed(cl, pk, err)
|
||||
}
|
||||
@@ -286,7 +302,7 @@ func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
|
||||
// containing the bytes sent.
|
||||
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketSent) {
|
||||
hook.OnPacketSent(cl, pk, b)
|
||||
}
|
||||
@@ -298,7 +314,7 @@ func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribe) {
|
||||
pk = hook.OnSubscribe(cl, pk)
|
||||
}
|
||||
@@ -308,7 +324,7 @@ func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribed) {
|
||||
hook.OnSubscribed(cl, pk, reasonCodes)
|
||||
}
|
||||
@@ -320,7 +336,7 @@ func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
|
||||
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
|
||||
// group in a custom manner (such as based on client id, ip, etc).
|
||||
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSelectSubscribers) {
|
||||
subs = hook.OnSelectSubscribers(subs, pk)
|
||||
}
|
||||
@@ -333,7 +349,7 @@ func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subsc
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribe) {
|
||||
pk = hook.OnUnsubscribe(cl, pk)
|
||||
}
|
||||
@@ -343,19 +359,19 @@ func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribed) {
|
||||
hook.OnUnsubscribed(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnMessage
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnPublished
|
||||
// in that it allows you to modify you to modify the incoming packet before it is processed.
|
||||
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
||||
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublish) {
|
||||
npk, err := hook.OnPublish(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
@@ -374,16 +390,26 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublished) {
|
||||
hook.OnPublished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublishDropped is called when a message to a client was dropped instead of delivered
|
||||
// such as when a client is too slow to respond.
|
||||
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublishDropped) {
|
||||
hook.OnPublishDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainMessage) {
|
||||
hook.OnRetainMessage(cl, pk, r)
|
||||
}
|
||||
@@ -394,7 +420,7 @@ func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
|
||||
// In other words, this method is called when a new inflight message is created or resent.
|
||||
// It is typically used to store a new inflight message.
|
||||
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosPublish) {
|
||||
hook.OnQosPublish(cl, pk, sent, resends)
|
||||
}
|
||||
@@ -405,7 +431,7 @@ func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends
|
||||
// In other words, when an inflight message is resolved.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosComplete) {
|
||||
hook.OnQosComplete(cl, pk)
|
||||
}
|
||||
@@ -413,22 +439,32 @@ func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires. In other words, when
|
||||
// an inflight message expires or is abandoned.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
// an inflight message expires or is abandoned. It is typically used to delete an
|
||||
// inflight message from a store.
|
||||
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosDropped) {
|
||||
hook.OnQosDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketIDExhausted is called when the client runs out of unused packet ids to
|
||||
// assign to a packet.
|
||||
func (h *Hooks) OnPacketIDExhausted(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketIDExhausted) {
|
||||
hook.OnPacketIDExhausted(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message. This method
|
||||
// differs from OnWillSent in that it allows you to modify the LWT message before it is
|
||||
// published. The return values of the hook methods are passed-through in the order
|
||||
// the hooks were attached.
|
||||
func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWill) {
|
||||
mlwt, err := hook.OnWill(cl, will)
|
||||
if err != nil {
|
||||
@@ -444,7 +480,7 @@ func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWillSent) {
|
||||
hook.OnWillSent(cl, pk)
|
||||
}
|
||||
@@ -453,7 +489,7 @@ func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
|
||||
// OnClientExpired is called when a client session has expired and should be deleted.
|
||||
func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnClientExpired) {
|
||||
hook.OnClientExpired(cl)
|
||||
}
|
||||
@@ -462,7 +498,7 @@ func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
|
||||
// OnRetainedExpired is called when a retained message has expired and should be deleted.
|
||||
func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainedExpired) {
|
||||
hook.OnRetainedExpired(filter)
|
||||
}
|
||||
@@ -472,7 +508,7 @@ func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
// StoredClients returns all clients, e.g. from a persistent store, is used to
|
||||
// populate the server clients list before start.
|
||||
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredClients) {
|
||||
v, err := hook.StoredClients()
|
||||
if err != nil {
|
||||
@@ -492,7 +528,7 @@ func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
|
||||
// used to populate the server subscriptions list before start.
|
||||
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSubscriptions) {
|
||||
v, err := hook.StoredSubscriptions()
|
||||
if err != nil {
|
||||
@@ -512,7 +548,7 @@ func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
|
||||
// and is used to populate the restored clients with inflight messages before start.
|
||||
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredInflightMessages) {
|
||||
v, err := hook.StoredInflightMessages()
|
||||
if err != nil {
|
||||
@@ -532,7 +568,7 @@ func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
|
||||
// and is used to populate the server topics with retained messages before start.
|
||||
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredRetainedMessages) {
|
||||
v, err := hook.StoredRetainedMessages()
|
||||
if err != nil {
|
||||
@@ -551,7 +587,7 @@ func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSysInfo) {
|
||||
v, err := hook.StoredSysInfo()
|
||||
if err != nil {
|
||||
@@ -573,7 +609,7 @@ func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check connecting users against an existing user database.
|
||||
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnectAuthenticate) {
|
||||
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
|
||||
return true
|
||||
@@ -589,7 +625,7 @@ func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check publishing and subscribing users against an existing permissions or roles database.
|
||||
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
for _, hook := range h.internal {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnACLCheck) {
|
||||
if ok := hook.OnACLCheck(cl, topic, write); ok {
|
||||
return true
|
||||
@@ -600,19 +636,6 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnExpireInflights is called when the server issues a clear request for expired
|
||||
// inflight messages. Expiry should be the time after which the message is no longer
|
||||
// valid (usually some time in the past). A message has expired if it's created time
|
||||
// is older than time.Now() minus Inflight TTL. This method can be used to expire
|
||||
// old inflight messages in a persistent store which doesnt support per-item TTL.
|
||||
func (h *Hooks) OnExpireInflights(cl *Client, expiry int64) {
|
||||
for _, hook := range h.internal {
|
||||
if hook.Provides(OnExpireInflights) {
|
||||
hook.OnExpireInflights(cl, expiry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HookBase provides a set of default methods for each hook. It should be embedded in
|
||||
// all hooks.
|
||||
type HookBase struct {
|
||||
@@ -728,6 +751,9 @@ func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, err
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
|
||||
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
|
||||
|
||||
@@ -740,6 +766,9 @@ func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPacketIDExhausted is called when the client runs out of unused packet ids to assign to a packet.
|
||||
func (h *HookBase) OnPacketIDExhausted(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message.
|
||||
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
|
||||
return will, nil
|
||||
@@ -754,9 +783,6 @@ func (h *HookBase) OnClientExpired(cl *Client) {}
|
||||
// OnRetainedExpired is called when a retained message for a topic has expired.
|
||||
func (h *HookBase) OnRetainedExpired(topic string) {}
|
||||
|
||||
// OnExpireInflights is called when the server issues a clear request for expired inflight messages.
|
||||
func (h *HookBase) OnExpireInflights(cl *Client, expiry int64) {}
|
||||
|
||||
// StoredClients returns all clients from a store.
|
||||
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
@@ -79,7 +80,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -182,6 +182,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if cl.StopCause() == packets.ErrSessionTakenOver {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
|
||||
@@ -198,11 +202,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Qos: reasonCodes[i],
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
@@ -347,32 +355,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var v []storage.Message
|
||||
err := h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
err := h.db.Delete(m.ID, new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", m.ID).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
err := h.db.Delete(retainedKey(filter), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
|
||||
@@ -381,6 +370,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/asdine/storm/v3"
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
@@ -169,6 +168,21 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -218,6 +232,29 @@ func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -332,6 +369,21 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -418,48 +470,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var v []storage.Message
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, v, 1)
|
||||
require.Equal(t, "i1", v[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
@@ -85,7 +85,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -185,6 +184,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if cl.StopCause() == packets.ErrSessionTakenOver {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
@@ -201,12 +204,17 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Qos: reasonCodes[i],
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
@@ -369,34 +377,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the
|
||||
// provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var v []storage.Message
|
||||
err := h.db.Find("T", storage.InflightKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to read inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range v {
|
||||
if m.Created < expiry || m.Created == 0 {
|
||||
err := h.db.DeleteStruct(&storage.Message{ID: m.ID})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("client", cl.ID).Msg("failed to clear inflight data")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
|
||||
}
|
||||
@@ -404,6 +391,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -211,6 +211,21 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -226,6 +241,29 @@ func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -340,6 +378,21 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
@@ -426,48 +479,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey, Created: time.Now().Unix() - 1})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey, Created: time.Now().Unix() - 20})
|
||||
require.NoError(t, err)
|
||||
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var v []storage.Message
|
||||
err = h.db.Find("T", storage.InflightKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
require.Len(t, v, 1)
|
||||
require.Equal(t, "i1", v[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
@@ -82,7 +83,6 @@ func (h *Hook) Provides(b byte) bool {
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.OnExpireInflights,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
@@ -199,6 +199,10 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if cl.StopCause() == packets.ErrSessionTakenOver {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
@@ -215,11 +219,15 @@ func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []by
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Qos: reasonCodes[i],
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
|
||||
@@ -363,37 +371,13 @@ func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
}
|
||||
}
|
||||
|
||||
// OnExpireInflights removes all inflight messages which have passed the
|
||||
// provided expiry time.
|
||||
func (h *Hook) OnExpireInflights(cl *mqtt.Client, expiry int64) {
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll inflight data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
|
||||
}
|
||||
|
||||
if d.Created < expiry || d.Created == 0 {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), d.ID).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
|
||||
@@ -402,6 +386,11 @@ func (h *Hook) OnRetainedExpired(filter string) {
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
@@ -252,6 +253,22 @@ func TestOnClientExpired(t *testing.T) {
|
||||
require.ErrorIs(t, redis.Nil, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
@@ -268,6 +285,28 @@ func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectSessionTakenOver(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
|
||||
testClient := &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
testClient.Stop(packets.ErrSessionTakenOver)
|
||||
teardown(t, h)
|
||||
h.OnDisconnect(testClient, nil, true)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
@@ -391,6 +430,22 @@ func TestOnRetainedExpired(t *testing.T) {
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
@@ -483,60 +538,6 @@ func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnExpireInflights(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
n := time.Now().Unix()
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1",
|
||||
&storage.Message{ID: "i1", T: storage.InflightKey, Created: n - 1},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2",
|
||||
&storage.Message{ID: "i2", T: storage.InflightKey, Created: n - 20},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3",
|
||||
&storage.Message{ID: "i3", T: storage.InflightKey},
|
||||
).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
|
||||
var r []storage.Message
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 1)
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
err = d.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
r = append(r, d)
|
||||
}
|
||||
require.Len(t, r, 1)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnExpireInflightsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnExpireInflights(client, time.Now().Unix()-10)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
@@ -116,6 +117,36 @@ func (d *Message) UnmarshalBinary(data []byte) error {
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// ToPacket converts a storage.Message to a standard packet.
|
||||
func (d *Message) ToPacket() packets.Packet {
|
||||
pk := packets.Packet{
|
||||
FixedHeader: d.FixedHeader,
|
||||
PacketID: d.PacketID,
|
||||
TopicName: d.TopicName,
|
||||
Payload: d.Payload,
|
||||
Origin: d.Origin,
|
||||
Created: d.Created,
|
||||
Properties: packets.Properties{
|
||||
PayloadFormat: d.Properties.PayloadFormat,
|
||||
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
|
||||
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
|
||||
ContentType: d.Properties.ContentType,
|
||||
ResponseTopic: d.Properties.ResponseTopic,
|
||||
CorrelationData: d.Properties.CorrelationData,
|
||||
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
|
||||
TopicAlias: d.Properties.TopicAlias,
|
||||
User: d.Properties.User,
|
||||
},
|
||||
}
|
||||
|
||||
// Return a deep copy of the packet data otherwise the slices will
|
||||
// continue pointing at the values from the storage packet.
|
||||
pk = pk.Copy(true)
|
||||
pk.FixedHeader.Dup = d.FixedHeader.Dup
|
||||
|
||||
return pk
|
||||
}
|
||||
|
||||
// Subscription is a storable representation of an mqtt subscription.
|
||||
type Subscription struct {
|
||||
T string `json:"t"`
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
@@ -103,6 +104,7 @@ var (
|
||||
ClientsMaximum: 7,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
PacketsReceived: 12,
|
||||
PacketsSent: 13,
|
||||
Retained: 15,
|
||||
@@ -110,7 +112,7 @@ var (
|
||||
InflightDropped: 17,
|
||||
},
|
||||
}
|
||||
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
|
||||
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
|
||||
)
|
||||
|
||||
func TestClientMarshalBinary(t *testing.T) {
|
||||
@@ -192,3 +194,35 @@ func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SystemInfo{}, d)
|
||||
}
|
||||
|
||||
func TestMessageToPacket(t *testing.T) {
|
||||
d := messageStruct
|
||||
pk := d.ToPacket()
|
||||
|
||||
require.Equal(t, packets.Packet{
|
||||
Payload: []byte("payload"),
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Remaining: d.FixedHeader.Remaining,
|
||||
Type: d.FixedHeader.Type,
|
||||
Qos: d.FixedHeader.Qos,
|
||||
Dup: d.FixedHeader.Dup,
|
||||
Retain: d.FixedHeader.Retain,
|
||||
},
|
||||
Origin: d.Origin,
|
||||
TopicName: d.TopicName,
|
||||
Properties: packets.Properties{
|
||||
PayloadFormat: d.Properties.PayloadFormat,
|
||||
PayloadFormatFlag: d.Properties.PayloadFormatFlag,
|
||||
MessageExpiryInterval: d.Properties.MessageExpiryInterval,
|
||||
ContentType: d.Properties.ContentType,
|
||||
ResponseTopic: d.Properties.ResponseTopic,
|
||||
CorrelationData: d.Properties.CorrelationData,
|
||||
SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
|
||||
TopicAlias: d.Properties.TopicAlias,
|
||||
User: d.Properties.User,
|
||||
},
|
||||
PacketID: 100,
|
||||
Created: d.Created,
|
||||
}, pk)
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -26,6 +27,10 @@ type modifiedHookBase struct {
|
||||
|
||||
var errTestHook = errors.New("error")
|
||||
|
||||
func (h *modifiedHookBase) ID() string {
|
||||
return "modified"
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Init(config any) error {
|
||||
if config != nil {
|
||||
return errTestHook
|
||||
@@ -177,12 +182,20 @@ func TestHooksProvides(t *testing.T) {
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHooksAddAndLen(t *testing.T) {
|
||||
func TestHooksAddLenGetAll(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(1), h.Len())
|
||||
|
||||
err = h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(2), h.Len())
|
||||
|
||||
all := h.GetAll()
|
||||
require.Equal(t, "base", all[0].ID())
|
||||
require.Equal(t, "modified", all[1].ID())
|
||||
}
|
||||
|
||||
func TestHooksAddInitFailure(t *testing.T) {
|
||||
@@ -223,14 +236,15 @@ func TestHooksNonReturns(t *testing.T) {
|
||||
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
|
||||
h.OnUnsubscribed(cl, packets.Packet{})
|
||||
h.OnPublished(cl, packets.Packet{})
|
||||
h.OnPublishDropped(cl, packets.Packet{})
|
||||
h.OnRetainMessage(cl, packets.Packet{}, 0)
|
||||
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
|
||||
h.OnQosComplete(cl, packets.Packet{})
|
||||
h.OnQosDropped(cl, packets.Packet{})
|
||||
h.OnPacketIDExhausted(cl, packets.Packet{})
|
||||
h.OnWillSent(cl, packets.Packet{})
|
||||
h.OnClientExpired(cl)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
h.OnExpireInflights(cl, time.Now().Unix()-1)
|
||||
|
||||
// on second iteration, check added hook methods
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
|
||||
25
inflight.go
25
inflight.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -57,6 +58,18 @@ func (i *Inflight) Len() int {
|
||||
return len(i.internal)
|
||||
}
|
||||
|
||||
// Clone returns a new instance of Inflight with the same message data.
|
||||
// This is used when transferring inflights from a taken-over session.
|
||||
func (i *Inflight) Clone() *Inflight {
|
||||
c := NewInflights()
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
for k, v := range i.internal {
|
||||
c.internal[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// GetAll returns all the inflight messages.
|
||||
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
|
||||
i.RLock()
|
||||
@@ -103,14 +116,14 @@ func (i *Inflight) Delete(id uint16) bool {
|
||||
}
|
||||
|
||||
// TakeRecieveQuota reduces the receive quota by 1.
|
||||
func (i *Inflight) TakeReceiveQuota() {
|
||||
func (i *Inflight) DecreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) > 0 {
|
||||
atomic.AddInt32(&i.receiveQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// TakeRecieveQuota increases the receive quota by 1.
|
||||
func (i *Inflight) ReturnReceiveQuota() {
|
||||
func (i *Inflight) IncreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
|
||||
atomic.AddInt32(&i.receiveQuota, 1)
|
||||
}
|
||||
@@ -122,15 +135,15 @@ func (i *Inflight) ResetReceiveQuota(n int32) {
|
||||
atomic.StoreInt32(&i.maximumReceiveQuota, n)
|
||||
}
|
||||
|
||||
// TakeSendQuota reduces the send quota by 1.
|
||||
func (i *Inflight) TakeSendQuota() {
|
||||
// DecreaseSendQuota reduces the send quota by 1.
|
||||
func (i *Inflight) DecreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) > 0 {
|
||||
atomic.AddInt32(&i.sendQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// ReturnSendQuota increases the send quota by 1.
|
||||
func (i *Inflight) ReturnSendQuota() {
|
||||
// IncreaseSendQuota increases the send quota by 1.
|
||||
func (i *Inflight) IncreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
|
||||
atomic.AddInt32(&i.sendQuota, 1)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func TestInflightSet(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.True(t, r)
|
||||
@@ -24,7 +25,7 @@ func TestInflightSet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightGet(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
msg, ok := cl.State.Inflight.Get(2)
|
||||
@@ -33,7 +34,7 @@ func TestInflightGet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
@@ -55,13 +56,23 @@ func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInflightLen(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestInflightClone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
cloned := cl.State.Inflight.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.NotSame(t, cloned, cl.State.Inflight)
|
||||
}
|
||||
|
||||
func TestInflightDelete(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
|
||||
require.NotNil(t, cl.State.Inflight.internal[3])
|
||||
@@ -94,12 +105,12 @@ func TestReceiveQuota(t *testing.T) {
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Return 1
|
||||
i.ReturnReceiveQuota()
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.ReturnReceiveQuota()
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
@@ -109,12 +120,12 @@ func TestReceiveQuota(t *testing.T) {
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Take 1
|
||||
i.TakeReceiveQuota()
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.TakeReceiveQuota()
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
}
|
||||
@@ -136,12 +147,12 @@ func TestSendQuota(t *testing.T) {
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Return 1
|
||||
i.ReturnSendQuota()
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.ReturnSendQuota()
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
@@ -151,18 +162,18 @@ func TestSendQuota(t *testing.T) {
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Take 1
|
||||
i.TakeSendQuota()
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.TakeSendQuota()
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
}
|
||||
|
||||
func TestNextImmediate(t *testing.T) {
|
||||
cl, _, _ := newClient()
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
@@ -106,28 +107,7 @@ func (l *HTTPStats) Close(closeClients CloseFn) {
|
||||
|
||||
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
|
||||
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
|
||||
info := &system.Info{
|
||||
Version: l.sysInfo.Version,
|
||||
Started: atomic.LoadInt64(&l.sysInfo.Started),
|
||||
Time: atomic.LoadInt64(&l.sysInfo.Time),
|
||||
Uptime: atomic.LoadInt64(&l.sysInfo.Uptime),
|
||||
BytesReceived: atomic.LoadInt64(&l.sysInfo.BytesReceived),
|
||||
BytesSent: atomic.LoadInt64(&l.sysInfo.BytesSent),
|
||||
ClientsConnected: atomic.LoadInt64(&l.sysInfo.ClientsConnected),
|
||||
ClientsMaximum: atomic.LoadInt64(&l.sysInfo.ClientsMaximum),
|
||||
ClientsTotal: atomic.LoadInt64(&l.sysInfo.ClientsTotal),
|
||||
ClientsDisconnected: atomic.LoadInt64(&l.sysInfo.ClientsDisconnected),
|
||||
MessagesReceived: atomic.LoadInt64(&l.sysInfo.MessagesReceived),
|
||||
MessagesSent: atomic.LoadInt64(&l.sysInfo.MessagesSent),
|
||||
InflightDropped: atomic.LoadInt64(&l.sysInfo.InflightDropped),
|
||||
Subscriptions: atomic.LoadInt64(&l.sysInfo.Subscriptions),
|
||||
PacketsReceived: atomic.LoadInt64(&l.sysInfo.PacketsReceived),
|
||||
PacketsSent: atomic.LoadInt64(&l.sysInfo.PacketsSent),
|
||||
Retained: atomic.LoadInt64(&l.sysInfo.Retained),
|
||||
Inflight: atomic.LoadInt64(&l.sysInfo.Inflight),
|
||||
MemoryAlloc: atomic.LoadInt64(&l.sysInfo.MemoryAlloc),
|
||||
Threads: atomic.LoadInt64(&l.sysInfo.Threads),
|
||||
}
|
||||
info := *l.sysInfo.Clone()
|
||||
|
||||
out, err := json.MarshalIndent(info, "", "\t")
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
||||
88
listeners/net.go
Normal file
88
listeners/net.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Net is a listener for establishing client connections on basic TCP protocol.
|
||||
type Net struct { // [MQTT-4.2.0-1]
|
||||
mu sync.Mutex
|
||||
listener net.Listener // a net.Listener which will listen for new clients
|
||||
id string // the internal id of the listener
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
|
||||
func NewNet(id string, listener net.Listener) *Net {
|
||||
return &Net{
|
||||
id: id,
|
||||
listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Net) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *Net) Address() string {
|
||||
return l.listener.Addr().String()
|
||||
}
|
||||
|
||||
// Protocol returns the network of the listener.
|
||||
func (l *Net) Protocol() string {
|
||||
return l.listener.Addr().Network()
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Net) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *Net) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Net) Close(closeClients CloseFn) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listener != nil {
|
||||
err := l.listener.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
105
listeners/net_test.go
Normal file
105
listeners/net_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewNet(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.id)
|
||||
}
|
||||
|
||||
func TestNetID(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestNetAddress(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, n.Addr().String(), l.Address())
|
||||
}
|
||||
|
||||
func TestNetProtocol(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestNetInit(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNetServeAndClose(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestNetEstablishThenEnd(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("tcp", n.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
||||
98
listeners/unixsock.go
Normal file
98
listeners/unixsock.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
|
||||
type UnixSock struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
|
||||
func NewUnixSock(id, address string) *UnixSock {
|
||||
return &UnixSock{
|
||||
id: id,
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *UnixSock) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *UnixSock) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *UnixSock) Protocol() string {
|
||||
return "unix"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *UnixSock) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
_ = os.Remove(l.address)
|
||||
l.listen, err = net.Listen("unix", l.address)
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new UnixSock connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *UnixSock) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *UnixSock) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
96
listeners/unixsock_test.go
Normal file
96
listeners/unixsock_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testUnixAddr = "mochi.sock"
|
||||
|
||||
func TestNewUnixSock(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testUnixAddr, l.address)
|
||||
}
|
||||
|
||||
func TestUnixSockID(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestUnixSockAddress(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, testUnixAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestUnixSockProtocol(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "unix", l.Protocol())
|
||||
}
|
||||
|
||||
func TestUnixSockInit(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewUnixSock("t2", testUnixAddr)
|
||||
err = l2.Init(&logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnixSockServeAndClose(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("unix", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -136,25 +138,35 @@ type wsConn struct {
|
||||
}
|
||||
|
||||
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
|
||||
func (ws *wsConn) Read(p []byte) (n int, err error) {
|
||||
func (ws *wsConn) Read(p []byte) (int, error) {
|
||||
op, r, err := ws.c.NextReader()
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if op != websocket.BinaryMessage {
|
||||
err = ErrInvalidMessage
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return r.Read(p)
|
||||
var n, br int
|
||||
for {
|
||||
br, err = r.Read(p[n:])
|
||||
n += br
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes bytes to the websocket connection.
|
||||
func (ws *wsConn) Write(p []byte) (n int, err error) {
|
||||
err = ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
func (ws *wsConn) Write(p []byte) (int, error) {
|
||||
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
@@ -375,7 +376,7 @@ func TestEncodeUint16(t *testing.T) {
|
||||
result = encodeUint16(32767)
|
||||
require.Equal(t, []byte{0x7f, 0xff}, result)
|
||||
|
||||
result = encodeUint16(65535)
|
||||
result = encodeUint16(math.MaxUint16)
|
||||
require.Equal(t, []byte{0xff, 0xff}, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
// Code contains a reason code and reason string for a response.
|
||||
@@ -20,7 +21,7 @@ func (c Code) Error() string {
|
||||
}
|
||||
|
||||
var (
|
||||
// QosCodes indicicates the reason codes for each Qos byte.
|
||||
// QosCodes indicates the reason codes for each Qos byte.
|
||||
QosCodes = map[byte]Code{
|
||||
0: CodeGrantedQos0,
|
||||
1: CodeGrantedQos1,
|
||||
@@ -112,15 +113,35 @@ var (
|
||||
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
|
||||
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
|
||||
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
|
||||
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
|
||||
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
|
||||
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
|
||||
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
|
||||
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
|
||||
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
|
||||
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
|
||||
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
|
||||
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"}
|
||||
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
|
||||
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
||||
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
||||
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
||||
|
||||
// MQTTv3 specific bytes.
|
||||
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
||||
Err3ClientIdentifierNotValid = Code{Code: 0x02}
|
||||
Err3ServerUnavailable = Code{Code: 0x03}
|
||||
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
|
||||
Err3NotAuthorized = Code{Code: 0x05}
|
||||
|
||||
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
|
||||
// This is required because MQTTv3 has different return byte specification.
|
||||
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
|
||||
V5CodesToV3 = map[Code]Code{
|
||||
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
|
||||
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
|
||||
ErrServerUnavailable: Err3ServerUnavailable,
|
||||
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
|
||||
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
|
||||
ErrBadUsernameOrPassword: Err3NotAuthorized,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -14,22 +16,23 @@ import (
|
||||
|
||||
// All of the valid packet types and their packet identifier.
|
||||
const (
|
||||
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
|
||||
Connect // 1
|
||||
Connack // 2
|
||||
Publish // 3
|
||||
Puback // 4
|
||||
Pubrec // 5
|
||||
Pubrel // 6
|
||||
Pubcomp // 7
|
||||
Subscribe // 8
|
||||
Suback // 9
|
||||
Unsubscribe // 10
|
||||
Unsuback // 11
|
||||
Pingreq // 12
|
||||
Pingresp // 13
|
||||
Disconnect // 14
|
||||
Auth // 15
|
||||
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
|
||||
Connect // 1
|
||||
Connack // 2
|
||||
Publish // 3
|
||||
Puback // 4
|
||||
Pubrec // 5
|
||||
Pubrel // 6
|
||||
Pubcomp // 7
|
||||
Subscribe // 8
|
||||
Suback // 9
|
||||
Unsubscribe // 10
|
||||
Unsuback // 11
|
||||
Pingreq // 12
|
||||
Pingresp // 13
|
||||
Disconnect // 14
|
||||
Auth // 15
|
||||
WillProperties byte = 99 // Special byte for validating Will Properties.
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -207,7 +210,10 @@ func (pk *Packet) Copy(allowTransfer bool) Packet {
|
||||
Created: pk.Created,
|
||||
Expiry: pk.Expiry,
|
||||
Origin: pk.Origin,
|
||||
PacketID: pk.PacketID, // ... ? Packet ID must not be transferred (in this manner)
|
||||
}
|
||||
|
||||
if allowTransfer {
|
||||
p.PacketID = pk.PacketID
|
||||
}
|
||||
|
||||
if len(pk.Connect.ProtocolName) > 0 {
|
||||
@@ -308,7 +314,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
(&pk.Properties).Encode(pk, pb, 0)
|
||||
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -317,7 +323,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
if pk.Connect.WillFlag {
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
(&pk.Connect).WillProperties.Encode(pk, pb, 0)
|
||||
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -378,7 +384,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
pk.Connect.ClientIdentifier, offset, err = decodeString(buf, offset) //[MQTT-3.1.3-1] [MQTT-3.1.3-2] [MQTT-3.1.3-3] [MQTT-3.1.3-4]
|
||||
@@ -388,11 +394,11 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
|
||||
if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
|
||||
if pk.ProtocolVersion == 5 {
|
||||
n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
|
||||
n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
|
||||
if err != nil {
|
||||
return ErrMalformedWillProperties
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
pk.Connect.WillTopic, offset, err = decodeString(buf, offset)
|
||||
@@ -407,6 +413,10 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
}
|
||||
|
||||
if pk.Connect.UsernameFlag { // [MQTT-3.1.3-12]
|
||||
if offset >= len(buf) { // we are at the end of the packet
|
||||
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
|
||||
}
|
||||
|
||||
pk.Connect.Username, offset, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return ErrMalformedUsername
|
||||
@@ -438,18 +448,14 @@ func (pk *Packet) ConnectValidate() Code {
|
||||
return ErrProtocolViolationReservedBit // [MQTT-3.1.2-3]
|
||||
}
|
||||
|
||||
if len(pk.Connect.Password) > 65535 {
|
||||
if len(pk.Connect.Password) > math.MaxUint16 {
|
||||
return ErrProtocolViolationPasswordTooLong
|
||||
}
|
||||
|
||||
if len(pk.Connect.Username) > 65535 {
|
||||
if len(pk.Connect.Username) > math.MaxUint16 {
|
||||
return ErrProtocolViolationUsernameTooLong
|
||||
}
|
||||
|
||||
if pk.Connect.UsernameFlag && len(pk.Connect.Username) == 0 {
|
||||
return ErrProtocolViolationFlagNoUsername // [MQTT-3.1.2-17]
|
||||
}
|
||||
|
||||
if !pk.Connect.UsernameFlag && len(pk.Connect.Username) > 0 {
|
||||
return ErrProtocolViolationUsernameNoFlag // [MQTT-3.1.2-16]
|
||||
}
|
||||
@@ -462,7 +468,7 @@ func (pk *Packet) ConnectValidate() Code {
|
||||
return ErrProtocolViolationPasswordNoFlag // [MQTT-3.1.2-18]
|
||||
}
|
||||
|
||||
if len(pk.Connect.ClientIdentifier) > 65535 {
|
||||
if len(pk.Connect.ClientIdentifier) > math.MaxUint16 {
|
||||
return ErrClientIdentifierNotValid
|
||||
}
|
||||
|
||||
@@ -491,7 +497,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -534,7 +540,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -603,7 +609,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload))
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -639,7 +645,7 @@ func (pk *Packet) PublishDecode(buf []byte) error {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
pk.Payload = buf[offset:]
|
||||
@@ -687,7 +693,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
}
|
||||
@@ -828,7 +834,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes))
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -856,7 +862,7 @@ func (pk *Packet) SubackDecode(buf []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
pk.ReasonCodes = buf[offset:]
|
||||
@@ -885,7 +891,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -913,7 +919,7 @@ func (pk *Packet) SubscribeDecode(buf []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
var filter string
|
||||
@@ -980,7 +986,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -1009,7 +1015,7 @@ func (pk *Packet) UnsubackDecode(buf []byte) error {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
|
||||
offset += n + 1
|
||||
offset += n
|
||||
|
||||
pk.ReasonCodes = buf[offset:]
|
||||
}
|
||||
@@ -1033,7 +1039,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -1061,7 +1067,7 @@ func (pk *Packet) UnsubscribeDecode(buf []byte) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedProperties)
|
||||
}
|
||||
offset += n + 1
|
||||
offset += n
|
||||
}
|
||||
|
||||
var filter string
|
||||
@@ -1096,7 +1102,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
@@ -463,6 +464,9 @@ func TestCopy(t *testing.T) {
|
||||
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
|
||||
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
|
||||
|
||||
pkcc := tt.Packet.Copy(false)
|
||||
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
@@ -41,11 +42,11 @@ const (
|
||||
|
||||
// validPacketProperties indicates which properties are valid for which packet types.
|
||||
var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropPayloadFormat: {Publish: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1},
|
||||
PropContentType: {Publish: 1},
|
||||
PropResponseTopic: {Publish: 1},
|
||||
PropCorrelationData: {Publish: 1},
|
||||
PropPayloadFormat: {Publish: 1, WillProperties: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
|
||||
PropContentType: {Publish: 1, WillProperties: 1},
|
||||
PropResponseTopic: {Publish: 1, WillProperties: 1},
|
||||
PropCorrelationData: {Publish: 1, WillProperties: 1},
|
||||
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
|
||||
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
|
||||
PropAssignedClientID: {Connack: 1},
|
||||
@@ -53,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropRequestProblemInfo: {Connect: 1},
|
||||
PropWillDelayInterval: {Connect: 1},
|
||||
PropWillDelayInterval: {WillProperties: 1},
|
||||
PropRequestResponseInfo: {Connect: 1},
|
||||
PropResponseInfo: {Connack: 1},
|
||||
PropServerReference: {Connack: 1, Disconnect: 1},
|
||||
@@ -63,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropTopicAlias: {Publish: 1},
|
||||
PropMaximumQos: {Connack: 1},
|
||||
PropRetainAvailable: {Connack: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
|
||||
PropMaximumPacketSize: {Connect: 1, Connack: 1},
|
||||
PropWildcardSubAvailable: {Connack: 1},
|
||||
PropSubIDAvailable: {Connack: 1},
|
||||
@@ -193,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool {
|
||||
}
|
||||
|
||||
// Encode encodes properties into a bytes buffer.
|
||||
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
pkt := pk.FixedHeader.Type
|
||||
|
||||
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
|
||||
buf.WriteByte(PropPayloadFormat)
|
||||
buf.WriteByte(p.PayloadFormat)
|
||||
@@ -216,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseTopic)
|
||||
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropCorrelationData)
|
||||
buf.Write(encodeBytes(p.CorrelationData))
|
||||
}
|
||||
@@ -276,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.WriteByte(p.RequestResponseInfo)
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseInfo)
|
||||
buf.Write(encodeString(p.ResponseInfo))
|
||||
}
|
||||
@@ -288,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
|
||||
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
|
||||
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
b := encodeString(p.ReasonString)
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
|
||||
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
|
||||
buf.WriteByte(PropReasonString)
|
||||
buf.Write(b)
|
||||
}
|
||||
@@ -321,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.WriteByte(p.RetainAvailable)
|
||||
}
|
||||
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
for _, v := range p.User {
|
||||
pb.WriteByte(PropUser)
|
||||
@@ -330,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
}
|
||||
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
|
||||
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
|
||||
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
|
||||
buf.Write(pb.Bytes())
|
||||
}
|
||||
}
|
||||
@@ -360,18 +359,19 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
}
|
||||
|
||||
// Decode decodes property bytes into a properties struct.
|
||||
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
|
||||
if p == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n, _, err = DecodeLength(b)
|
||||
var bu int
|
||||
n, bu, err = DecodeLength(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return n, nil
|
||||
return n + bu, nil
|
||||
}
|
||||
|
||||
bt := b.Bytes()
|
||||
@@ -379,11 +379,11 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
for offset := 0; offset < n; {
|
||||
k, offset, err = decodeByte(bt, offset)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if _, ok := validPacketProperties[k][pk]; !ok {
|
||||
return n, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
|
||||
if _, ok := validPacketProperties[k][pkt]; !ok {
|
||||
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
switch k {
|
||||
@@ -405,7 +405,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
|
||||
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
|
||||
offset += bu
|
||||
@@ -451,7 +451,7 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
var k, v string
|
||||
k, offset, err = decodeString(bt, offset)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
v, offset, err = decodeString(bt, offset)
|
||||
p.User = append(p.User, UserProperty{Key: k, Val: v})
|
||||
@@ -469,9 +469,9 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return n, err
|
||||
return n + bu, err
|
||||
}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
return n + bu, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
@@ -201,14 +202,14 @@ func init() {
|
||||
func TestEncodeProperties(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
|
||||
require.Equal(t, propertiesBytes, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
|
||||
@@ -218,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
|
||||
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
|
||||
@@ -231,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) {
|
||||
|
||||
pr := tmp{}
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
|
||||
pr.p.Encode(Reserved, Mods{}, b, 0)
|
||||
require.Equal(t, []byte{}, b.Bytes())
|
||||
}
|
||||
|
||||
@@ -239,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) {
|
||||
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
|
||||
props := new(Properties)
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
|
||||
require.Equal(t, []byte{0x00}, b.Bytes())
|
||||
}
|
||||
|
||||
@@ -249,7 +250,7 @@ func TestDecodeProperties(t *testing.T) {
|
||||
props := new(Properties)
|
||||
n, err := props.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 172, n)
|
||||
require.Equal(t, 172+2, n)
|
||||
require.EqualValues(t, propertiesStruct, *props)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
// TPacketCase contains data for cross-checking the encoding and decoding
|
||||
@@ -70,7 +71,7 @@ const (
|
||||
TConnectInvalidWillFlagNoPayload
|
||||
TConnectInvalidWillFlagQosOutOfRange
|
||||
TConnectInvalidWillSurplusRetain
|
||||
TConnectNotCleanNoClientID
|
||||
TConnectZeroByteUsername
|
||||
TConnectSpecInvalidUTF8D800
|
||||
TConnectSpecInvalidUTF8DFFF
|
||||
TConnectSpecInvalidUTF80000
|
||||
@@ -88,6 +89,7 @@ const (
|
||||
TConnackServerUnavailable
|
||||
TConnackBadUsernamePassword
|
||||
TConnackBadUsernamePasswordNoSession
|
||||
TConnackMqtt5BadUsernamePasswordNoSession
|
||||
TConnackNotAuthorised
|
||||
TConnackMalSessionPresent
|
||||
TConnackMalReturnCode
|
||||
@@ -248,26 +250,26 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "mqtt v3.1.1",
|
||||
Primary: true,
|
||||
RawBytes: []byte{
|
||||
Connect << 4, 16, // Fixed header
|
||||
Connect << 4, 15, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
0, // Packet Flags
|
||||
0, 60, // Keepalive
|
||||
0, 4, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', '3', // Client ID "zen"
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
},
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connect,
|
||||
Remaining: 16,
|
||||
Remaining: 15,
|
||||
},
|
||||
ProtocolVersion: 4,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
Clean: false,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "zen3",
|
||||
ClientIdentifier: "zen",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -424,9 +426,9 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Connect << 4, 28, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
194, // Packet Flags
|
||||
0, 20, // Keepalive
|
||||
4, // Protocol Version
|
||||
0 | 1<<6 | 1<<7, // Packet Flags
|
||||
0, 20, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 5, // Username MSB+LSB
|
||||
@@ -442,7 +444,7 @@ var TPacketData = map[byte]TPacketCases{
|
||||
ProtocolVersion: 4,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
Clean: true,
|
||||
Clean: false,
|
||||
Keepalive: 20,
|
||||
ClientIdentifier: "zen",
|
||||
UsernameFlag: true,
|
||||
@@ -496,6 +498,43 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnectZeroByteUsername,
|
||||
Desc: "username flag but 0 byte username",
|
||||
Group: "decode",
|
||||
RawBytes: []byte{
|
||||
Connect << 4, 23, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
5, // Protocol Version
|
||||
130, // Packet Flags
|
||||
0, 30, // Keepalive
|
||||
5, // length
|
||||
17, 0, 0, 0, 120, // Session Expiry Interval (17)
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 0, // Username MSB+LSB
|
||||
},
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connect,
|
||||
Remaining: 23,
|
||||
},
|
||||
ProtocolVersion: 5,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
Clean: true,
|
||||
Keepalive: 30,
|
||||
ClientIdentifier: "zen",
|
||||
Username: []byte{},
|
||||
UsernameFlag: true,
|
||||
},
|
||||
Properties: Properties{
|
||||
SessionExpiryInterval: uint32(120),
|
||||
SessionExpiryIntervalFlag: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// Fail States
|
||||
{
|
||||
@@ -622,6 +661,24 @@ var TPacketData = map[byte]TPacketCases{
|
||||
'm', 'o', 'c',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Case: TConnectInvalidFlagNoUsername,
|
||||
Desc: "username flag with no username bytes",
|
||||
Group: "decode",
|
||||
FailFirst: ErrProtocolViolationFlagNoUsername,
|
||||
RawBytes: []byte{
|
||||
Connect << 4, 17, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
5, // Protocol Version
|
||||
130, // Flags
|
||||
0, 20, // Keepalive
|
||||
0,
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnectMalPassword,
|
||||
Desc: "malformed password",
|
||||
@@ -782,20 +839,6 @@ var TPacketData = map[byte]TPacketCases{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnectInvalidFlagNoUsername,
|
||||
Desc: "has username flag but no username",
|
||||
Group: "validate",
|
||||
Expect: ErrProtocolViolationFlagNoUsername,
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{Type: Connect},
|
||||
ProtocolVersion: 4,
|
||||
Connect: ConnectParams{
|
||||
ProtocolName: []byte("MQTT"),
|
||||
UsernameFlag: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnectInvalidUsernameNoFlag,
|
||||
Desc: "has username but no flag",
|
||||
@@ -1315,10 +1358,28 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Desc: "bad username or password no session",
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 2, // fixed header
|
||||
0, // No session present
|
||||
ErrBadUsernameOrPassword.Code,
|
||||
0, // No session present
|
||||
Err3NotAuthorized.Code, // use v3 remapping
|
||||
},
|
||||
Packet: &Packet{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 2,
|
||||
},
|
||||
ReasonCode: Err3NotAuthorized.Code,
|
||||
},
|
||||
},
|
||||
{
|
||||
Case: TConnackMqtt5BadUsernamePasswordNoSession,
|
||||
Desc: "mqtt5 bad username or password no session",
|
||||
RawBytes: []byte{
|
||||
Connack << 4, 3, // fixed header
|
||||
0, // No session present
|
||||
ErrBadUsernameOrPassword.Code,
|
||||
0,
|
||||
},
|
||||
Packet: &Packet{
|
||||
ProtocolVersion: 5,
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Connack,
|
||||
Remaining: 2,
|
||||
@@ -1326,6 +1387,7 @@ var TPacketData = map[byte]TPacketCases{
|
||||
ReasonCode: ErrBadUsernameOrPassword.Code,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Case: TConnackNotAuthorised,
|
||||
Desc: "not authorised",
|
||||
@@ -1803,13 +1865,10 @@ var TPacketData = map[byte]TPacketCases{
|
||||
Case: TPublishRetainMqtt5,
|
||||
Desc: "retain mqtt5",
|
||||
RawBytes: []byte{
|
||||
Publish<<4 | 1<<0, 35, // Fixed header
|
||||
Publish<<4 | 1<<0, 19, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
16, // properties length
|
||||
38, // User Properties (38)
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
0, // properties length
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
|
||||
},
|
||||
Packet: &Packet{
|
||||
@@ -1817,18 +1876,11 @@ var TPacketData = map[byte]TPacketCases{
|
||||
FixedHeader: FixedHeader{
|
||||
Type: Publish,
|
||||
Retain: true,
|
||||
Remaining: 35,
|
||||
Remaining: 19,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Properties: Properties{
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
},
|
||||
},
|
||||
Payload: []byte("hello mochi"),
|
||||
TopicName: "a/b/c",
|
||||
Properties: Properties{},
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
|
||||
345
server.go
345
server.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
// package mqtt provides a high performance, fully compliant MQTT v5 broker server with v3.1.1 backward compatibility.
|
||||
package mqtt
|
||||
|
||||
@@ -25,10 +26,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "2.0.0" // the current server version.
|
||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||
defaultFanPoolSize uint64 = 64 // the number of concurrent workers in the pool
|
||||
defaultFanPoolQueueSize uint64 = 32 * 128 // the capacity of each worker queue
|
||||
Version = "2.2.8" // the current server version.
|
||||
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -46,6 +45,7 @@ var (
|
||||
SharedSubAvailable: 1, // shared subscriptions are available
|
||||
ServerKeepAlive: 10, // default keepalive for clients
|
||||
MinimumProtocolVersion: 3, // minimum supported mqtt version (3.0.0)
|
||||
MaximumClientWritesPending: 1024 * 8, // maximum number of pending message writes for a client
|
||||
}
|
||||
|
||||
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists.
|
||||
@@ -55,18 +55,20 @@ var (
|
||||
// Capabilities indicates the capabilities and features provided by the server.
|
||||
type Capabilities struct {
|
||||
MaximumMessageExpiryInterval int64
|
||||
MaximumClientWritesPending int32
|
||||
MaximumSessionExpiryInterval uint32
|
||||
MaximumPacketSize uint32
|
||||
maximumPacketID uint32 // unexported, used for testing only
|
||||
ReceiveMaximum uint16
|
||||
TopicAliasMaximum uint16
|
||||
ServerKeepAlive uint16
|
||||
SharedSubAvailable byte
|
||||
MinimumProtocolVersion byte
|
||||
Compatibilities Compatibilities
|
||||
MaximumQos byte
|
||||
RetainAvailable byte
|
||||
WildcardSubAvailable byte
|
||||
SubIDAvailable byte
|
||||
SharedSubAvailable byte
|
||||
MinimumProtocolVersion byte
|
||||
}
|
||||
|
||||
// Compatibilities provides flags for using compatibility modes.
|
||||
@@ -79,9 +81,17 @@ type Compatibilities struct {
|
||||
|
||||
// Options contains configurable options for the server.
|
||||
type Options struct {
|
||||
// Capabilities defines the server features and behaviour.
|
||||
// Capabilities defines the server features and behaviour. If you only wish to modify
|
||||
// several of these values, set them explicitly - e.g.
|
||||
// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
|
||||
Capabilities *Capabilities
|
||||
|
||||
// ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer.
|
||||
ClientNetWriteBufferSize int
|
||||
|
||||
// ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer.
|
||||
ClientNetReadBufferSize int
|
||||
|
||||
// Logger specifies a custom configured implementation of zerolog to override
|
||||
// the servers default logger configuration. If you wish to change the log level,
|
||||
// of the default logger, you can do so by setting
|
||||
@@ -90,16 +100,6 @@ type Options struct {
|
||||
// server.Log = &l
|
||||
Logger *zerolog.Logger
|
||||
|
||||
// FanPoolSize is the number of individual workers and queues to initialize.
|
||||
// Bigger is not necessarily better, and you should rely on defaults unless
|
||||
// you have know what you are doing.
|
||||
FanPoolSize uint64
|
||||
|
||||
// FanPoolQueueSize is the size of the queue per worker. Increase this value
|
||||
// accordingly if you anticipate having intermittent but massive numbers of
|
||||
// messages. Cluster support is roadmapped.
|
||||
FanPoolQueueSize uint64
|
||||
|
||||
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
|
||||
SysTopicResendInterval int64
|
||||
}
|
||||
@@ -112,7 +112,6 @@ type Server struct {
|
||||
Clients *Clients // clients known to the broker
|
||||
Topics *TopicsIndex // an index of topic filter subscriptions and retained messages
|
||||
Info *system.Info // values about the server commonly known as $SYS topics
|
||||
fanpool *FanPool // a fixed size worker pool for processing inbound and outbound messages
|
||||
loop *loop // loop contains tickers for the system event loop
|
||||
done chan bool // indicate that the server is ending
|
||||
Log *zerolog.Logger // minimal no-alloc logger
|
||||
@@ -131,10 +130,10 @@ type loop struct {
|
||||
|
||||
// ops contains server values which can be propagated to other structs.
|
||||
type ops struct {
|
||||
capabilities *Capabilities // a pointer to the server capabilities, for referencing in clients
|
||||
info *system.Info // pointers to server system info
|
||||
hooks *Hooks // pointer to the server hooks
|
||||
log *zerolog.Logger // a structured logger for the client
|
||||
options *Options // a pointer to the server options and capabilities, for referencing in clients
|
||||
info *system.Info // pointers to server system info
|
||||
hooks *Hooks // pointer to the server hooks
|
||||
log *zerolog.Logger // a structured logger for the client
|
||||
}
|
||||
|
||||
// New returns a new instance of mochi mqtt broker. Optional parameters
|
||||
@@ -164,8 +163,7 @@ func New(opts *Options) *Server {
|
||||
Version: Version,
|
||||
Started: time.Now().Unix(),
|
||||
},
|
||||
fanpool: NewFanPool(opts.FanPoolSize, opts.FanPoolQueueSize),
|
||||
Log: opts.Logger,
|
||||
Log: opts.Logger,
|
||||
hooks: &Hooks{
|
||||
Log: opts.Logger,
|
||||
},
|
||||
@@ -180,16 +178,18 @@ func (o *Options) ensureDefaults() {
|
||||
o.Capabilities = DefaultServerCapabilities
|
||||
}
|
||||
|
||||
o.Capabilities.maximumPacketID = math.MaxUint16 // spec maximum is 65535
|
||||
|
||||
if o.SysTopicResendInterval == 0 {
|
||||
o.SysTopicResendInterval = defaultSysTopicInterval
|
||||
}
|
||||
|
||||
if o.FanPoolSize == 0 {
|
||||
o.FanPoolSize = defaultFanPoolSize
|
||||
if o.ClientNetWriteBufferSize == 0 {
|
||||
o.ClientNetWriteBufferSize = 1024 * 2
|
||||
}
|
||||
|
||||
if o.FanPoolQueueSize < 1 {
|
||||
o.FanPoolQueueSize = defaultFanPoolQueueSize
|
||||
if o.ClientNetReadBufferSize == 0 {
|
||||
o.ClientNetReadBufferSize = 1024 * 2
|
||||
}
|
||||
|
||||
if o.Logger == nil {
|
||||
@@ -198,6 +198,33 @@ func (o *Options) ensureDefaults() {
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient returns a new Client instance, populated with all the required values and
|
||||
// references to be used with the server. If you are using this client to directly publish
|
||||
// messages from the embedding application, set the inline flag to true to bypass ACL and
|
||||
// topic validation checks.
|
||||
func (s *Server) NewClient(c net.Conn, listener string, id string, inline bool) *Client {
|
||||
cl := newClient(c, &ops{ // [MQTT-3.1.2-6] implicit
|
||||
options: s.Options,
|
||||
info: s.Info,
|
||||
hooks: s.hooks,
|
||||
log: s.Log,
|
||||
})
|
||||
|
||||
cl.ID = id
|
||||
cl.Net.Listener = listener
|
||||
|
||||
if inline { // inline clients bypass acl and some validity checks.
|
||||
cl.Net.Inline = true
|
||||
// By default we don't want to restrict developer publishes,
|
||||
// but if you do, reset this after creating inline client.
|
||||
cl.State.Inflight.ResetReceiveQuota(math.MaxInt32)
|
||||
} else {
|
||||
go cl.WriteLoop() // can only write to real clients
|
||||
}
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// AddHook attaches a new Hook to the server. Ideally, this should be called
|
||||
// before the server is started with s.Serve().
|
||||
func (s *Server) AddHook(hook Hook, config any) error {
|
||||
@@ -280,27 +307,21 @@ func (s *Server) eventLoop() {
|
||||
}
|
||||
|
||||
// EstablishConnection establishes a new client when a listener accepts a new connection.
|
||||
func (s *Server) EstablishConnection(lid string, c net.Conn) error {
|
||||
cl := NewClient(c, &ops{ // [MQTT-3.1.2-6] implicit
|
||||
capabilities: s.Options.Capabilities,
|
||||
info: s.Info,
|
||||
hooks: s.hooks,
|
||||
log: s.Log,
|
||||
})
|
||||
|
||||
return s.attachClient(cl, lid)
|
||||
func (s *Server) EstablishConnection(listener string, c net.Conn) error {
|
||||
cl := s.NewClient(c, listener, "", false)
|
||||
return s.attachClient(cl, listener)
|
||||
}
|
||||
|
||||
// attachClient validates an incoming client connection and if viable, attaches the client
|
||||
// to the server, performs session housekeeping, and reads incoming packets.
|
||||
func (s *Server) attachClient(cl *Client, lid string) error {
|
||||
func (s *Server) attachClient(cl *Client, listener string) error {
|
||||
defer cl.Stop(nil)
|
||||
pk, err := s.readConnectionPacket(cl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read connection: %w", err)
|
||||
}
|
||||
|
||||
cl.ParseConnect(lid, pk)
|
||||
cl.ParseConnect(listener, pk)
|
||||
code := s.validateConnect(cl, pk) // [MQTT-3.1.4-1] [MQTT-3.1.4-2]
|
||||
if code != packets.CodeSuccess {
|
||||
if err := s.sendConnack(cl, code, false); err != nil {
|
||||
@@ -346,18 +367,17 @@ func (s *Server) attachClient(cl *Client, lid string) error {
|
||||
if err != nil {
|
||||
s.sendLWT(cl)
|
||||
cl.Stop(err)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
} else {
|
||||
cl.Properties.Will = Will{} // [MQTT-3.14.4-3] [MQTT-3.1.2-10]
|
||||
}
|
||||
|
||||
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", lid).Msg("client disconnected")
|
||||
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryIntervalFlag && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
|
||||
s.Log.Debug().Str("client", cl.ID).Err(err).Str("remote", cl.Net.Remote).Str("listener", listener).Msg("client disconnected")
|
||||
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
|
||||
s.hooks.OnDisconnect(cl, err, expire)
|
||||
if expire {
|
||||
s.unsubscribeClient(cl)
|
||||
|
||||
if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
|
||||
cl.ClearInflights(math.MaxInt64, 0)
|
||||
s.UnsubscribeClient(cl)
|
||||
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
|
||||
}
|
||||
|
||||
@@ -431,25 +451,40 @@ func (s *Server) validateConnect(cl *Client, pk packets.Packet) packets.Code {
|
||||
// session is abandoned.
|
||||
func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
|
||||
if existing, ok := s.Clients.Get(pk.Connect.ClientIdentifier); ok {
|
||||
existing.Lock()
|
||||
defer existing.Unlock()
|
||||
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
||||
if pk.Connect.Clean || (existing.Properties.Clean && cl.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
|
||||
s.unsubscribeClient(existing)
|
||||
s.DisconnectClient(existing, packets.ErrSessionTakenOver) // [MQTT-3.1.4-3]
|
||||
if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
|
||||
s.UnsubscribeClient(existing)
|
||||
existing.ClearInflights(math.MaxInt64, 0)
|
||||
return false // [MQTT-3.2.2-3]
|
||||
atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred
|
||||
return false // [MQTT-3.2.2-3]
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&existing.State.isTakenOver, 1)
|
||||
if existing.State.Inflight.Len() > 0 {
|
||||
cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5]
|
||||
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 {
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
}
|
||||
}
|
||||
|
||||
cl.State.Inflight = existing.State.Inflight // [MQTT-3.1.2-5]
|
||||
for _, sub := range existing.State.Subscriptions.GetAll() {
|
||||
existed := !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
|
||||
if !existed {
|
||||
atomic.AddInt64(&s.Info.Subscriptions, 1)
|
||||
}
|
||||
cl.State.Subscriptions.Add(sub.Filter, sub)
|
||||
s.publishRetainedToClient(cl, sub, existed)
|
||||
}
|
||||
|
||||
// Clean the state of the existing client to prevent sequential take-overs
|
||||
// from increasing memory usage by inflights + subs * client-id.
|
||||
s.UnsubscribeClient(existing)
|
||||
existing.ClearInflights(math.MaxInt64, 0)
|
||||
s.Log.Debug().Str("client", cl.ID).
|
||||
Str("old_remote", existing.Net.Remote).
|
||||
Str("new_remote", cl.Net.Remote).
|
||||
Msg("session taken over")
|
||||
|
||||
return true // [MQTT-3.2.2-3]
|
||||
}
|
||||
|
||||
@@ -469,6 +504,12 @@ func (s *Server) sendConnack(cl *Client, reason packets.Code, present bool) erro
|
||||
}
|
||||
|
||||
if reason.Code >= packets.ErrUnspecifiedError.Code {
|
||||
if cl.Properties.ProtocolVersion < 5 {
|
||||
if v3reason, ok := packets.V5CodesToV3[reason]; ok { // NB v3 3.2.2.3 Connack return codes
|
||||
reason = v3reason
|
||||
}
|
||||
}
|
||||
|
||||
properties.ReasonString = reason.Reason
|
||||
ack := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
@@ -568,7 +609,7 @@ func (s *Server) processPacket(cl *Client, pk packets.Packet) error {
|
||||
if ok := cl.State.Inflight.Delete(next.PacketID); ok {
|
||||
atomic.AddInt64(&s.Info.Inflight, -1)
|
||||
}
|
||||
cl.State.Inflight.TakeSendQuota()
|
||||
cl.State.Inflight.DecreaseSendQuota()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -591,6 +632,24 @@ func (s *Server) processPingreq(cl *Client, _ packets.Packet) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Publish publishes a publish packet into the broker as if it were sent from the speicfied client.
|
||||
// This is a convenience function which wraps InjectPacket. As such, this method can publish packets
|
||||
// to any topic (including $SYS) and bypass ACL checks. The qos byte is used for limiting the
|
||||
// outbound qos (mqtt v5) rather than issuing to the broker (we assume qos 2 complete).
|
||||
func (s *Server) Publish(topic string, payload []byte, retain bool, qos byte) error {
|
||||
cl := s.NewClient(nil, "local", "inline", true)
|
||||
return s.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Qos: qos,
|
||||
Retain: retain,
|
||||
},
|
||||
TopicName: topic,
|
||||
Payload: payload,
|
||||
PacketID: uint16(qos), // we never process the inbound qos, but we need a packet id for validity checks.
|
||||
})
|
||||
}
|
||||
|
||||
// InjectPacket injects a packet into the broker as if it were sent from the specified client.
|
||||
// InlineClients using this method can publish packets to any topic (including $SYS) and bypass ACL checks.
|
||||
func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
|
||||
@@ -611,7 +670,7 @@ func (s *Server) InjectPacket(cl *Client, pk packets.Packet) error {
|
||||
|
||||
// processPublish processes a Publish packet.
|
||||
func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
if !IsValidFilter(pk.TopicName, true) && !cl.Net.Inline {
|
||||
if !cl.Net.Inline && !IsValidFilter(pk.TopicName, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -619,20 +678,22 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
return s.DisconnectClient(cl, packets.ErrReceiveMaximum) // ~[MQTT-3.3.4-7] ~[MQTT-3.3.4-8]
|
||||
}
|
||||
|
||||
if !s.hooks.OnACLCheck(cl, pk.TopicName, true) && !cl.Net.Inline {
|
||||
if !cl.Net.Inline && !s.hooks.OnACLCheck(cl, pk.TopicName, true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
pk.Origin = cl.ID
|
||||
pk.Created = time.Now().Unix()
|
||||
|
||||
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
|
||||
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
|
||||
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
|
||||
return cl.WritePacket(ack)
|
||||
}
|
||||
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
|
||||
atomic.AddInt64(&s.Info.Inflight, -1)
|
||||
if !cl.Net.Inline {
|
||||
if pki, ok := cl.State.Inflight.Get(pk.PacketID); ok {
|
||||
if pki.FixedHeader.Type == packets.Pubrec { // [MQTT-4.3.3-10]
|
||||
ack := s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.ErrPacketIdentifierInUse)
|
||||
return cl.WritePacket(ack)
|
||||
}
|
||||
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
|
||||
atomic.AddInt64(&s.Info.Inflight, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -655,14 +716,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Qos == 0 {
|
||||
s.fanpool.Enqueue(cl.ID, func() {
|
||||
s.publishToSubscribers(pk)
|
||||
})
|
||||
|
||||
s.publishToSubscribers(pk)
|
||||
s.hooks.OnPublished(cl, pk)
|
||||
return nil
|
||||
}
|
||||
|
||||
cl.State.Inflight.TakeReceiveQuota()
|
||||
cl.State.Inflight.DecreaseReceiveQuota()
|
||||
ack := s.buildAck(pk.PacketID, packets.Puback, 0, pk.Properties, packets.QosCodes[pk.FixedHeader.Qos]) // [MQTT-4.3.2-4]
|
||||
if pk.FixedHeader.Qos == 2 {
|
||||
ack = s.buildAck(pk.PacketID, packets.Pubrec, 0, pk.Properties, packets.CodeSuccess) // [MQTT-3.3.4-1] [MQTT-4.3.3-8]
|
||||
@@ -670,6 +729,7 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
|
||||
if ok := cl.State.Inflight.Set(ack); ok {
|
||||
atomic.AddInt64(&s.Info.Inflight, 1)
|
||||
s.hooks.OnQosPublish(cl, ack, ack.Created, 0)
|
||||
}
|
||||
|
||||
err := cl.WritePacket(ack)
|
||||
@@ -681,15 +741,12 @@ func (s *Server) processPublish(cl *Client, pk packets.Packet) error {
|
||||
if ok := cl.State.Inflight.Delete(ack.PacketID); ok {
|
||||
atomic.AddInt64(&s.Info.Inflight, -1)
|
||||
}
|
||||
cl.State.Inflight.ReturnReceiveQuota()
|
||||
s.hooks.OnQosComplete(cl, pk)
|
||||
cl.State.Inflight.IncreaseReceiveQuota()
|
||||
s.hooks.OnQosComplete(cl, ack)
|
||||
}
|
||||
|
||||
s.fanpool.Enqueue(cl.ID, func() {
|
||||
s.publishToSubscribers(pk)
|
||||
})
|
||||
|
||||
s.hooks.OnPublish(cl, pk)
|
||||
s.publishToSubscribers(pk)
|
||||
s.hooks.OnPublished(cl, pk)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -733,13 +790,13 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (out packets.Packet, err error) {
|
||||
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
|
||||
if sub.NoLocal && pk.Origin == cl.ID {
|
||||
return pk, nil // [MQTT-3.8.3-3]
|
||||
}
|
||||
|
||||
out = pk.Copy(false)
|
||||
if !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
|
||||
out := pk.Copy(false)
|
||||
if cl.Properties.ProtocolVersion == 5 && !sub.RetainAsPublished { // ![MQTT-3.3.1-13]
|
||||
out.FixedHeader.Retain = false // [MQTT-3.3.1-12]
|
||||
}
|
||||
|
||||
@@ -769,6 +826,7 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
|
||||
if out.FixedHeader.Qos > 0 {
|
||||
i, err := cl.NextPacketID() // [MQTT-4.3.2-1] [MQTT-4.3.3-1]
|
||||
if err != nil {
|
||||
s.hooks.OnPacketIDExhausted(cl, pk)
|
||||
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Msg("packet ids exhausted")
|
||||
return out, packets.ErrQuotaExceeded
|
||||
}
|
||||
@@ -779,22 +837,32 @@ func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packet
|
||||
if ok := cl.State.Inflight.Set(out); ok { // [MQTT-4.3.2-3] [MQTT-4.3.3-3]
|
||||
atomic.AddInt64(&s.Info.Inflight, 1)
|
||||
s.hooks.OnQosPublish(cl, out, out.Created, 0)
|
||||
cl.State.Inflight.DecreaseSendQuota()
|
||||
}
|
||||
|
||||
if sentQuota == 0 && atomic.LoadInt32(&cl.State.Inflight.maximumSendQuota) > 0 {
|
||||
out.Expiry = -1
|
||||
cl.State.Inflight.Set(out)
|
||||
return pk, nil
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
if cl.Net.conn == nil || atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return pk, packets.CodeDisconnect
|
||||
if cl.Net.Conn == nil || cl.Closed() {
|
||||
return out, packets.CodeDisconnect
|
||||
}
|
||||
|
||||
cl.State.Inflight.TakeSendQuota()
|
||||
select {
|
||||
case cl.State.outbound <- &out:
|
||||
atomic.AddInt32(&cl.State.outboundQty, 1)
|
||||
default:
|
||||
atomic.AddInt64(&s.Info.MessagesDropped, 1)
|
||||
cl.ops.hooks.OnPublishDropped(cl, pk)
|
||||
cl.State.Inflight.Delete(out.PacketID) // packet was dropped due to irregular circumstances, so rollback inflight.
|
||||
cl.State.Inflight.IncreaseSendQuota()
|
||||
return out, packets.ErrPendingClientWritesExceeded
|
||||
}
|
||||
|
||||
return out, cl.WritePacket(out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, existed bool) {
|
||||
@@ -809,7 +877,7 @@ func (s *Server) publishRetainedToClient(cl *Client, sub packets.Subscription, e
|
||||
for _, pkv := range s.Topics.Messages(sub.Filter) { // [MQTT-3.8.4-4]
|
||||
_, err := s.publishToClient(cl, sub, pkv)
|
||||
if err != nil {
|
||||
s.Log.Warn().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
|
||||
s.Log.Debug().Err(err).Str("client", cl.ID).Str("listener", cl.Net.Listener).Interface("packet", pkv).Msg("failed to publish retained message")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -843,7 +911,7 @@ func (s *Server) processPuback(cl *Client, pk packets.Packet) error {
|
||||
}
|
||||
|
||||
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.2-5]
|
||||
cl.State.Inflight.ReturnSendQuota()
|
||||
cl.State.Inflight.IncreaseSendQuota()
|
||||
atomic.AddInt64(&s.Info.Inflight, -1)
|
||||
s.hooks.OnQosComplete(cl, pk)
|
||||
}
|
||||
@@ -866,7 +934,7 @@ func (s *Server) processPubrec(cl *Client, pk packets.Packet) error {
|
||||
}
|
||||
|
||||
ack := s.buildAck(pk.PacketID, packets.Pubrel, 1, pk.Properties, packets.CodeSuccess) // [MQTT-4.3.3-4] ![MQTT-4.3.3-6]
|
||||
cl.State.Inflight.TakeReceiveQuota() // -1 RECV QUOTA
|
||||
cl.State.Inflight.DecreaseReceiveQuota() // -1 RECV QUOTA
|
||||
cl.State.Inflight.Set(ack) // [MQTT-4.3.3-5]
|
||||
return cl.WritePacket(ack)
|
||||
}
|
||||
@@ -893,8 +961,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA
|
||||
cl.State.Inflight.ReturnSendQuota() // +1 SENT QUOTA
|
||||
cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
|
||||
cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
|
||||
if ok := cl.State.Inflight.Delete(pk.PacketID); ok { // [MQTT-4.3.3-12]
|
||||
atomic.AddInt64(&s.Info.Inflight, -1)
|
||||
s.hooks.OnQosComplete(cl, pk)
|
||||
@@ -906,8 +974,8 @@ func (s *Server) processPubrel(cl *Client, pk packets.Packet) error {
|
||||
// processPubcomp processes a Pubcomp packet, denoting completion of a QOS 2 packet sent from the server.
|
||||
func (s *Server) processPubcomp(cl *Client, pk packets.Packet) error {
|
||||
// regardless of whether the pubcomp is a success or failure, we end the qos flow, delete inflight, and restore the quotas.
|
||||
cl.State.Inflight.ReturnReceiveQuota() // +1 RECV QUOTA
|
||||
cl.State.Inflight.ReturnSendQuota() // +1 SENT QUOTA
|
||||
cl.State.Inflight.IncreaseReceiveQuota() // +1 RECV QUOTA
|
||||
cl.State.Inflight.IncreaseSendQuota() // +1 SENT QUOTA
|
||||
if ok := cl.State.Inflight.Delete(pk.PacketID); ok {
|
||||
atomic.AddInt64(&s.Info.Inflight, -1)
|
||||
s.hooks.OnQosComplete(cl, pk)
|
||||
@@ -924,24 +992,24 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
||||
code = packets.ErrPacketIdentifierInUse
|
||||
}
|
||||
|
||||
existed := false
|
||||
filterExisted := make([]bool, len(pk.Filters))
|
||||
reasonCodes := make([]byte, len(pk.Filters))
|
||||
for i, sub := range pk.Filters {
|
||||
if code != packets.CodeSuccess {
|
||||
reasonCodes[i] = code.Code // NB 3.9.3 Non-normative 0x91
|
||||
continue
|
||||
} else if !IsValidFilter(sub.Filter, false) {
|
||||
reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
|
||||
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
|
||||
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
|
||||
} else if !s.hooks.OnACLCheck(cl, sub.Filter, false) {
|
||||
reasonCodes[i] = packets.ErrNotAuthorized.Code
|
||||
if s.Options.Capabilities.Compatibilities.ObscureNotAuthorized {
|
||||
reasonCodes[i] = packets.ErrUnspecifiedError.Code
|
||||
}
|
||||
} else if !IsValidFilter(sub.Filter, false) {
|
||||
reasonCodes[i] = packets.ErrTopicFilterInvalid.Code
|
||||
} else if sub.NoLocal && IsSharedFilter(sub.Filter) {
|
||||
reasonCodes[i] = packets.ErrProtocolViolationInvalidSharedNoLocal.Code // [MQTT-3.8.3-4]
|
||||
} else {
|
||||
existed = !s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
|
||||
if !existed {
|
||||
isNew := s.Topics.Subscribe(cl.ID, sub) // [MQTT-3.8.4-3]
|
||||
if isNew {
|
||||
atomic.AddInt64(&s.Info.Subscriptions, 1)
|
||||
}
|
||||
cl.State.Subscriptions.Add(sub.Filter, sub) // [MQTT-3.2.2-10]
|
||||
@@ -950,6 +1018,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
||||
sub.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
|
||||
}
|
||||
|
||||
filterExisted[i] = !isNew
|
||||
reasonCodes[i] = sub.Qos // [MQTT-3.9.3-1] [MQTT-3.8.4-7]
|
||||
}
|
||||
|
||||
@@ -984,7 +1053,7 @@ func (s *Server) processSubscribe(cl *Client, pk packets.Packet) error {
|
||||
continue
|
||||
}
|
||||
|
||||
s.publishRetainedToClient(cl, sub, existed)
|
||||
s.publishRetainedToClient(cl, sub, filterExisted[i])
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1034,20 +1103,32 @@ func (s *Server) processUnsubscribe(cl *Client, pk packets.Packet) error {
|
||||
return cl.WritePacket(ack)
|
||||
}
|
||||
|
||||
// unsubscribeClient unsubscribes a client from all of their subscriptions.
|
||||
func (s *Server) unsubscribeClient(cl *Client) {
|
||||
for k := range cl.State.Subscriptions.GetAll() {
|
||||
// UnsubscribeClient unsubscribes a client from all of their subscriptions.
|
||||
func (s *Server) UnsubscribeClient(cl *Client) {
|
||||
i := 0
|
||||
filterMap := cl.State.Subscriptions.GetAll()
|
||||
filters := make([]packets.Subscription, len(filterMap))
|
||||
for k := range filterMap {
|
||||
cl.State.Subscriptions.Delete(k)
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&cl.State.isTakenOver) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range filterMap {
|
||||
if s.Topics.Unsubscribe(k, cl.ID) {
|
||||
atomic.AddInt64(&s.Info.Subscriptions, -1)
|
||||
}
|
||||
filters[i] = v
|
||||
i++
|
||||
}
|
||||
s.hooks.OnUnsubscribed(cl, packets.Packet{Filters: filters})
|
||||
}
|
||||
|
||||
// processAuth processes an Auth packet.
|
||||
func (s *Server) processAuth(cl *Client, pk packets.Packet) error {
|
||||
_, err := s.hooks.OnAuthPacket(cl, pk)
|
||||
fmt.Println("err", err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1086,9 +1167,14 @@ func (s *Server) DisconnectClient(cl *Client, code packets.Code) error {
|
||||
out.Properties.ReasonString = code.Reason // // [MQTT-3.14.2-1]
|
||||
}
|
||||
|
||||
// We already have a code we are using to disconnect the client, so we are not
|
||||
// interested if the write packet fails due to a closed connection (as we are closing it).
|
||||
err := cl.WritePacket(out)
|
||||
if !s.Options.Capabilities.Compatibilities.PassiveClientDisconnect {
|
||||
cl.Stop(code)
|
||||
if code.Code >= packets.ErrUnspecifiedError.Code {
|
||||
return code
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -1130,6 +1216,7 @@ func (s *Server) publishSysTopics() {
|
||||
SysPrefix + "/broker/packets/sent": AtomicItoa(&s.Info.PacketsSent),
|
||||
SysPrefix + "/broker/messages/received": AtomicItoa(&s.Info.MessagesReceived),
|
||||
SysPrefix + "/broker/messages/sent": AtomicItoa(&s.Info.MessagesSent),
|
||||
SysPrefix + "/broker/messages/dropped": AtomicItoa(&s.Info.MessagesDropped),
|
||||
SysPrefix + "/broker/messages/inflight": AtomicItoa(&s.Info.Inflight),
|
||||
SysPrefix + "/broker/retained": AtomicItoa(&s.Info.Retained),
|
||||
SysPrefix + "/broker/subscriptions": AtomicItoa(&s.Info.Subscriptions),
|
||||
@@ -1151,8 +1238,6 @@ func (s *Server) publishSysTopics() {
|
||||
func (s *Server) Close() error {
|
||||
close(s.done)
|
||||
s.Listeners.CloseAll(s.closeListenerClients)
|
||||
s.fanpool.Close()
|
||||
s.fanpool.Wait()
|
||||
s.hooks.OnStopped()
|
||||
s.hooks.Stop()
|
||||
|
||||
@@ -1272,6 +1357,7 @@ func (s *Server) loadServerInfo(v system.Info) {
|
||||
atomic.StoreInt64(&s.Info.ClientsDisconnected, v.ClientsDisconnected)
|
||||
atomic.StoreInt64(&s.Info.MessagesReceived, v.MessagesReceived)
|
||||
atomic.StoreInt64(&s.Info.MessagesSent, v.MessagesSent)
|
||||
atomic.StoreInt64(&s.Info.MessagesDropped, v.MessagesDropped)
|
||||
atomic.StoreInt64(&s.Info.PacketsReceived, v.PacketsReceived)
|
||||
atomic.StoreInt64(&s.Info.PacketsSent, v.PacketsSent)
|
||||
atomic.StoreInt64(&s.Info.InflightDropped, v.InflightDropped)
|
||||
@@ -1303,9 +1389,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) {
|
||||
// loadClients restores clients from the datastore.
|
||||
func (s *Server) loadClients(v []storage.Client) {
|
||||
for _, c := range v {
|
||||
cl := newClientStub()
|
||||
cl.ID = c.ID
|
||||
cl.Net.Listener = c.Listener
|
||||
cl := s.NewClient(nil, c.Listener, c.ID, false)
|
||||
cl.Properties.Username = c.Username
|
||||
cl.Properties.Clean = c.Clean
|
||||
cl.Properties.ProtocolVersion = c.ProtocolVersion
|
||||
@@ -1331,25 +1415,7 @@ func (s *Server) loadClients(v []storage.Client) {
|
||||
func (s *Server) loadInflight(v []storage.Message) {
|
||||
for _, msg := range v {
|
||||
if client, ok := s.Clients.Get(msg.Origin); ok {
|
||||
client.State.Inflight.Set(packets.Packet{
|
||||
FixedHeader: msg.FixedHeader,
|
||||
PacketID: msg.PacketID,
|
||||
TopicName: msg.TopicName,
|
||||
Payload: msg.Payload,
|
||||
Origin: msg.Origin,
|
||||
Created: msg.Created,
|
||||
Properties: packets.Properties{
|
||||
PayloadFormat: msg.Properties.PayloadFormat,
|
||||
PayloadFormatFlag: msg.Properties.PayloadFormatFlag,
|
||||
MessageExpiryInterval: msg.Properties.MessageExpiryInterval,
|
||||
ContentType: msg.Properties.ContentType,
|
||||
ResponseTopic: msg.Properties.ResponseTopic,
|
||||
CorrelationData: msg.Properties.CorrelationData,
|
||||
SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier,
|
||||
TopicAlias: msg.Properties.TopicAlias,
|
||||
User: msg.Properties.User,
|
||||
},
|
||||
})
|
||||
client.State.Inflight.Set(msg.ToPacket())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1357,24 +1423,7 @@ func (s *Server) loadInflight(v []storage.Message) {
|
||||
// loadRetained restores retained messages from the datastore.
|
||||
func (s *Server) loadRetained(v []storage.Message) {
|
||||
for _, msg := range v {
|
||||
s.Topics.RetainMessage(packets.Packet{
|
||||
FixedHeader: msg.FixedHeader,
|
||||
TopicName: msg.TopicName,
|
||||
Payload: msg.Payload,
|
||||
Origin: msg.Origin,
|
||||
Created: msg.Created,
|
||||
Properties: packets.Properties{
|
||||
PayloadFormat: msg.Properties.PayloadFormat,
|
||||
PayloadFormatFlag: msg.Properties.PayloadFormatFlag,
|
||||
MessageExpiryInterval: msg.Properties.MessageExpiryInterval,
|
||||
ContentType: msg.Properties.ContentType,
|
||||
ResponseTopic: msg.Properties.ResponseTopic,
|
||||
CorrelationData: msg.Properties.CorrelationData,
|
||||
SubscriptionIdentifier: msg.Properties.SubscriptionIdentifier,
|
||||
TopicAlias: msg.Properties.TopicAlias,
|
||||
User: msg.Properties.User,
|
||||
},
|
||||
})
|
||||
s.Topics.RetainMessage(msg.ToPacket())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1412,8 +1461,10 @@ func (s *Server) clearExpiredRetainedMessages(now int64) {
|
||||
// clearExpiredInflights deletes any inflight messages which have expired.
|
||||
func (s *Server) clearExpiredInflights(now int64) {
|
||||
for _, client := range s.Clients.GetAll() {
|
||||
if d := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); d > 0 {
|
||||
s.hooks.OnExpireInflights(client, now)
|
||||
if deleted := client.ClearInflights(now, s.Options.Capabilities.MaximumMessageExpiryInterval); len(deleted) > 0 {
|
||||
for _, id := range deleted {
|
||||
s.hooks.OnQosDropped(client, packets.Packet{PacketID: id})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
541
server_test.go
541
server_test.go
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package system
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Info contains atomic counters and values for various server statistics
|
||||
// commonly found in $SYS topics (and others).
|
||||
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
|
||||
@@ -19,6 +22,7 @@ type Info struct {
|
||||
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
|
||||
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
|
||||
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
|
||||
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
|
||||
Retained int64 `json:"retained"` // total number of retained messages active on the broker
|
||||
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
|
||||
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
|
||||
@@ -28,3 +32,30 @@ type Info struct {
|
||||
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
|
||||
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
|
||||
}
|
||||
|
||||
// Clone makes a copy of Info using atomic operation
|
||||
func (i *Info) Clone() *Info {
|
||||
return &Info{
|
||||
Version: i.Version,
|
||||
Started: atomic.LoadInt64(&i.Started),
|
||||
Time: atomic.LoadInt64(&i.Time),
|
||||
Uptime: atomic.LoadInt64(&i.Uptime),
|
||||
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
|
||||
BytesSent: atomic.LoadInt64(&i.BytesSent),
|
||||
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
|
||||
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
|
||||
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
|
||||
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
|
||||
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
|
||||
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
|
||||
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
|
||||
Retained: atomic.LoadInt64(&i.Retained),
|
||||
Inflight: atomic.LoadInt64(&i.Inflight),
|
||||
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
|
||||
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
|
||||
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
|
||||
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
|
||||
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
|
||||
Threads: atomic.LoadInt64(&i.Threads),
|
||||
}
|
||||
}
|
||||
|
||||
37
system/system_test.go
Normal file
37
system/system_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
o := &Info{
|
||||
Version: "version",
|
||||
Started: 1,
|
||||
Time: 2,
|
||||
Uptime: 3,
|
||||
BytesReceived: 4,
|
||||
BytesSent: 5,
|
||||
ClientsConnected: 6,
|
||||
ClientsMaximum: 7,
|
||||
ClientsTotal: 8,
|
||||
ClientsDisconnected: 9,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
Retained: 12,
|
||||
Inflight: 13,
|
||||
InflightDropped: 14,
|
||||
Subscriptions: 15,
|
||||
PacketsReceived: 16,
|
||||
PacketsSent: 17,
|
||||
MemoryAlloc: 18,
|
||||
Threads: 19,
|
||||
}
|
||||
|
||||
n := o.Clone()
|
||||
|
||||
require.Equal(t, o, n)
|
||||
}
|
||||
14
topics.go
14
topics.go
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
@@ -300,6 +301,9 @@ func NewTopicsIndex() *TopicsIndex {
|
||||
// Subscribe adds a new subscription for a client to a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var existed bool
|
||||
prefix, _ := isolateParticle(subscription.Filter, 0)
|
||||
if strings.EqualFold(prefix, SharePrefix) {
|
||||
@@ -319,6 +323,9 @@ func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription
|
||||
// Unsubscribe removes a subscription filter for a client, returning true if the
|
||||
// subscription existed.
|
||||
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var d int
|
||||
if strings.HasPrefix(filter, SharePrefix) {
|
||||
d = 2
|
||||
@@ -345,7 +352,12 @@ func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
|
||||
// 1 if a retained message was added, and -1 if the retained message was removed.
|
||||
// 0 is returned if sequential empty payloads are received.
|
||||
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
n := x.set(pk.TopicName, 0)
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
if len(pk.Payload) > 0 {
|
||||
n.retainPath = pk.TopicName
|
||||
x.Retained.Add(pk.TopicName, pk)
|
||||
@@ -360,6 +372,7 @@ func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
|
||||
n.retainPath = ""
|
||||
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
|
||||
x.trim(n)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -618,6 +631,7 @@ type particle struct {
|
||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||
retainPath string // path of a retained message
|
||||
sync.Mutex // mutex for when making changes to the particle
|
||||
}
|
||||
|
||||
// newParticle returns a pointer to a new instance of particle.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
|
||||
2
vendor/golang.org/x/net/trace/histogram.go
generated
vendored
2
vendor/golang.org/x/net/trace/histogram.go
generated
vendored
@@ -32,7 +32,7 @@ type histogram struct {
|
||||
valueCount int64 // number of values recorded for single value
|
||||
}
|
||||
|
||||
// AddMeasurement records a value measurement observation to the histogram.
|
||||
// addMeasurement records a value measurement observation to the histogram.
|
||||
func (h *histogram) addMeasurement(value int64) {
|
||||
// TODO: assert invariant
|
||||
h.sum += value
|
||||
|
||||
2
vendor/golang.org/x/net/trace/trace.go
generated
vendored
2
vendor/golang.org/x/net/trace/trace.go
generated
vendored
@@ -395,7 +395,7 @@ func New(family, title string) Trace {
|
||||
}
|
||||
|
||||
func (tr *trace) Finish() {
|
||||
elapsed := time.Now().Sub(tr.Start)
|
||||
elapsed := time.Since(tr.Start)
|
||||
tr.mu.Lock()
|
||||
tr.Elapsed = elapsed
|
||||
tr.mu.Unlock()
|
||||
|
||||
30
vendor/golang.org/x/sys/internal/unsafeheader/unsafeheader.go
generated
vendored
30
vendor/golang.org/x/sys/internal/unsafeheader/unsafeheader.go
generated
vendored
@@ -1,30 +0,0 @@
|
||||
// Copyright 2020 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package unsafeheader contains header declarations for the Go runtime's
|
||||
// slice and string implementations.
|
||||
//
|
||||
// This package allows x/sys to use types equivalent to
|
||||
// reflect.SliceHeader and reflect.StringHeader without introducing
|
||||
// a dependency on the (relatively heavy) "reflect" package.
|
||||
package unsafeheader
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Slice is the runtime representation of a slice.
|
||||
// It cannot be used safely or portably and its representation may change in a later release.
|
||||
type Slice struct {
|
||||
Data unsafe.Pointer
|
||||
Len int
|
||||
Cap int
|
||||
}
|
||||
|
||||
// String is the runtime representation of a string.
|
||||
// It cannot be used safely or portably and its representation may change in a later release.
|
||||
type String struct {
|
||||
Data unsafe.Pointer
|
||||
Len int
|
||||
}
|
||||
31
vendor/golang.org/x/sys/unix/asm_bsd_ppc64.s
generated
vendored
Normal file
31
vendor/golang.org/x/sys/unix/asm_bsd_ppc64.s
generated
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build (darwin || freebsd || netbsd || openbsd) && gc
|
||||
// +build darwin freebsd netbsd openbsd
|
||||
// +build gc
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
//
|
||||
// System call support for ppc64, BSD
|
||||
//
|
||||
|
||||
// Just jump to package syscall's implementation for all these functions.
|
||||
// The runtime may know about them.
|
||||
|
||||
TEXT ·Syscall(SB),NOSPLIT,$0-56
|
||||
JMP syscall·Syscall(SB)
|
||||
|
||||
TEXT ·Syscall6(SB),NOSPLIT,$0-80
|
||||
JMP syscall·Syscall6(SB)
|
||||
|
||||
TEXT ·Syscall9(SB),NOSPLIT,$0-104
|
||||
JMP syscall·Syscall9(SB)
|
||||
|
||||
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
|
||||
JMP syscall·RawSyscall(SB)
|
||||
|
||||
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
|
||||
JMP syscall·RawSyscall6(SB)
|
||||
4
vendor/golang.org/x/sys/unix/dirent.go
generated
vendored
4
vendor/golang.org/x/sys/unix/dirent.go
generated
vendored
@@ -2,8 +2,8 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
|
||||
|
||||
package unix
|
||||
|
||||
|
||||
4
vendor/golang.org/x/sys/unix/gccgo.go
generated
vendored
4
vendor/golang.org/x/sys/unix/gccgo.go
generated
vendored
@@ -2,8 +2,8 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build gccgo && !aix
|
||||
// +build gccgo,!aix
|
||||
//go:build gccgo && !aix && !hurd
|
||||
// +build gccgo,!aix,!hurd
|
||||
|
||||
package unix
|
||||
|
||||
|
||||
4
vendor/golang.org/x/sys/unix/gccgo_c.c
generated
vendored
4
vendor/golang.org/x/sys/unix/gccgo_c.c
generated
vendored
@@ -2,8 +2,8 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build gccgo
|
||||
// +build !aix
|
||||
//go:build gccgo && !aix && !hurd
|
||||
// +build gccgo,!aix,!hurd
|
||||
|
||||
#include <errno.h>
|
||||
#include <stdint.h>
|
||||
|
||||
4
vendor/golang.org/x/sys/unix/ioctl.go
generated
vendored
4
vendor/golang.org/x/sys/unix/ioctl.go
generated
vendored
@@ -2,8 +2,8 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
|
||||
|
||||
package unix
|
||||
|
||||
|
||||
20
vendor/golang.org/x/sys/unix/ioctl_linux.go
generated
vendored
20
vendor/golang.org/x/sys/unix/ioctl_linux.go
generated
vendored
@@ -4,9 +4,7 @@
|
||||
|
||||
package unix
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
import "unsafe"
|
||||
|
||||
// IoctlRetInt performs an ioctl operation specified by req on a device
|
||||
// associated with opened file descriptor fd, and returns a non-negative
|
||||
@@ -217,3 +215,19 @@ func IoctlKCMAttach(fd int, info KCMAttach) error {
|
||||
func IoctlKCMUnattach(fd int, info KCMUnattach) error {
|
||||
return ioctlPtr(fd, SIOCKCMUNATTACH, unsafe.Pointer(&info))
|
||||
}
|
||||
|
||||
// IoctlLoopGetStatus64 gets the status of the loop device associated with the
|
||||
// file descriptor fd using the LOOP_GET_STATUS64 operation.
|
||||
func IoctlLoopGetStatus64(fd int) (*LoopInfo64, error) {
|
||||
var value LoopInfo64
|
||||
if err := ioctlPtr(fd, LOOP_GET_STATUS64, unsafe.Pointer(&value)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &value, nil
|
||||
}
|
||||
|
||||
// IoctlLoopSetStatus64 sets the status of the loop device associated with the
|
||||
// file descriptor fd using the LOOP_SET_STATUS64 operation.
|
||||
func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error {
|
||||
return ioctlPtr(fd, LOOP_SET_STATUS64, unsafe.Pointer(value))
|
||||
}
|
||||
|
||||
49
vendor/golang.org/x/sys/unix/mkall.sh
generated
vendored
49
vendor/golang.org/x/sys/unix/mkall.sh
generated
vendored
@@ -73,12 +73,12 @@ aix_ppc64)
|
||||
darwin_amd64)
|
||||
mkerrors="$mkerrors -m64"
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
|
||||
mkasm="go run mkasm_darwin.go"
|
||||
mkasm="go run mkasm.go"
|
||||
;;
|
||||
darwin_arm64)
|
||||
mkerrors="$mkerrors -m64"
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
|
||||
mkasm="go run mkasm_darwin.go"
|
||||
mkasm="go run mkasm.go"
|
||||
;;
|
||||
dragonfly_amd64)
|
||||
mkerrors="$mkerrors -m64"
|
||||
@@ -142,42 +142,60 @@ netbsd_arm64)
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
|
||||
;;
|
||||
openbsd_386)
|
||||
mkasm="go run mkasm.go"
|
||||
mkerrors="$mkerrors -m32"
|
||||
mksyscall="go run mksyscall.go -l32 -openbsd"
|
||||
mksyscall="go run mksyscall.go -l32 -openbsd -libc"
|
||||
mksysctl="go run mksysctl_openbsd.go"
|
||||
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
|
||||
;;
|
||||
openbsd_amd64)
|
||||
mkasm="go run mkasm.go"
|
||||
mkerrors="$mkerrors -m64"
|
||||
mksyscall="go run mksyscall.go -openbsd"
|
||||
mksyscall="go run mksyscall.go -openbsd -libc"
|
||||
mksysctl="go run mksysctl_openbsd.go"
|
||||
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs"
|
||||
;;
|
||||
openbsd_arm)
|
||||
mkasm="go run mkasm.go"
|
||||
mkerrors="$mkerrors"
|
||||
mksyscall="go run mksyscall.go -l32 -openbsd -arm"
|
||||
mksyscall="go run mksyscall.go -l32 -openbsd -arm -libc"
|
||||
mksysctl="go run mksysctl_openbsd.go"
|
||||
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
|
||||
# Let the type of C char be signed for making the bare syscall
|
||||
# API consistent across platforms.
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
|
||||
;;
|
||||
openbsd_arm64)
|
||||
mkasm="go run mkasm.go"
|
||||
mkerrors="$mkerrors -m64"
|
||||
mksyscall="go run mksyscall.go -openbsd"
|
||||
mksyscall="go run mksyscall.go -openbsd -libc"
|
||||
mksysctl="go run mksysctl_openbsd.go"
|
||||
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
|
||||
# Let the type of C char be signed for making the bare syscall
|
||||
# API consistent across platforms.
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
|
||||
;;
|
||||
openbsd_mips64)
|
||||
mkasm="go run mkasm.go"
|
||||
mkerrors="$mkerrors -m64"
|
||||
mksyscall="go run mksyscall.go -openbsd"
|
||||
mksyscall="go run mksyscall.go -openbsd -libc"
|
||||
mksysctl="go run mksysctl_openbsd.go"
|
||||
# Let the type of C char be signed for making the bare syscall
|
||||
# API consistent across platforms.
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
|
||||
;;
|
||||
openbsd_ppc64)
|
||||
mkasm="go run mkasm.go"
|
||||
mkerrors="$mkerrors -m64"
|
||||
mksyscall="go run mksyscall.go -openbsd -libc"
|
||||
mksysctl="go run mksysctl_openbsd.go"
|
||||
# Let the type of C char be signed for making the bare syscall
|
||||
# API consistent across platforms.
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
|
||||
;;
|
||||
openbsd_riscv64)
|
||||
mkasm="go run mkasm.go"
|
||||
mkerrors="$mkerrors -m64"
|
||||
mksyscall="go run mksyscall.go -openbsd -libc"
|
||||
mksysctl="go run mksysctl_openbsd.go"
|
||||
mksysnum="go run mksysnum.go 'https://cvsweb.openbsd.org/cgi-bin/cvsweb/~checkout~/src/sys/kern/syscalls.master'"
|
||||
# Let the type of C char be signed for making the bare syscall
|
||||
# API consistent across platforms.
|
||||
mktypes="GOARCH=$GOARCH go tool cgo -godefs -- -fsigned-char"
|
||||
@@ -214,11 +232,6 @@ esac
|
||||
if [ "$GOOSARCH" == "aix_ppc64" ]; then
|
||||
# aix/ppc64 script generates files instead of writing to stdin.
|
||||
echo "$mksyscall -tags $GOOS,$GOARCH $syscall_goos $GOOSARCH_in && gofmt -w zsyscall_$GOOSARCH.go && gofmt -w zsyscall_"$GOOSARCH"_gccgo.go && gofmt -w zsyscall_"$GOOSARCH"_gc.go " ;
|
||||
elif [ "$GOOS" == "darwin" ]; then
|
||||
# 1.12 and later, syscalls via libSystem
|
||||
echo "$mksyscall -tags $GOOS,$GOARCH,go1.12 $syscall_goos $GOOSARCH_in |gofmt >zsyscall_$GOOSARCH.go";
|
||||
# 1.13 and later, syscalls via libSystem (including syscallPtr)
|
||||
echo "$mksyscall -tags $GOOS,$GOARCH,go1.13 syscall_darwin.1_13.go |gofmt >zsyscall_$GOOSARCH.1_13.go";
|
||||
elif [ "$GOOS" == "illumos" ]; then
|
||||
# illumos code generation requires a --illumos switch
|
||||
echo "$mksyscall -illumos -tags illumos,$GOARCH syscall_illumos.go |gofmt > zsyscall_illumos_$GOARCH.go";
|
||||
@@ -232,5 +245,5 @@ esac
|
||||
if [ -n "$mksysctl" ]; then echo "$mksysctl |gofmt >$zsysctl"; fi
|
||||
if [ -n "$mksysnum" ]; then echo "$mksysnum |gofmt >zsysnum_$GOOSARCH.go"; fi
|
||||
if [ -n "$mktypes" ]; then echo "$mktypes types_$GOOS.go | go run mkpost.go > ztypes_$GOOSARCH.go"; fi
|
||||
if [ -n "$mkasm" ]; then echo "$mkasm $GOARCH"; fi
|
||||
if [ -n "$mkasm" ]; then echo "$mkasm $GOOS $GOARCH"; fi
|
||||
) | $run
|
||||
|
||||
4
vendor/golang.org/x/sys/unix/mkerrors.sh
generated
vendored
4
vendor/golang.org/x/sys/unix/mkerrors.sh
generated
vendored
@@ -642,7 +642,7 @@ errors=$(
|
||||
signals=$(
|
||||
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
|
||||
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' |
|
||||
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
|
||||
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
|
||||
sort
|
||||
)
|
||||
|
||||
@@ -652,7 +652,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
|
||||
sort >_error.grep
|
||||
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
|
||||
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' |
|
||||
egrep -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
|
||||
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
|
||||
sort >_signal.grep
|
||||
|
||||
echo '// mkerrors.sh' "$@"
|
||||
|
||||
14
vendor/golang.org/x/sys/unix/sockcmsg_unix.go
generated
vendored
14
vendor/golang.org/x/sys/unix/sockcmsg_unix.go
generated
vendored
@@ -52,6 +52,20 @@ func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) {
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
// ParseOneSocketControlMessage parses a single socket control message from b, returning the message header,
|
||||
// message data (a slice of b), and the remainder of b after that single message.
|
||||
// When there are no remaining messages, len(remainder) == 0.
|
||||
func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) {
|
||||
h, dbuf, err := socketControlMessageHeaderAndData(b)
|
||||
if err != nil {
|
||||
return Cmsghdr{}, nil, nil, err
|
||||
}
|
||||
if i := cmsgAlignOf(int(h.Len)); i < len(b) {
|
||||
remainder = b[i:]
|
||||
}
|
||||
return *h, dbuf, remainder, nil
|
||||
}
|
||||
|
||||
func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) {
|
||||
h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
|
||||
if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) {
|
||||
|
||||
27
vendor/golang.org/x/sys/unix/str.go
generated
vendored
27
vendor/golang.org/x/sys/unix/str.go
generated
vendored
@@ -1,27 +0,0 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||
|
||||
package unix
|
||||
|
||||
func itoa(val int) string { // do it here rather than with fmt to avoid dependency
|
||||
if val < 0 {
|
||||
return "-" + uitoa(uint(-val))
|
||||
}
|
||||
return uitoa(uint(val))
|
||||
}
|
||||
|
||||
func uitoa(val uint) string {
|
||||
var buf [32]byte // big enough for int64
|
||||
i := len(buf) - 1
|
||||
for val >= 10 {
|
||||
buf[i] = byte(val%10 + '0')
|
||||
i--
|
||||
val /= 10
|
||||
}
|
||||
buf[i] = byte(val + '0')
|
||||
return string(buf[i:])
|
||||
}
|
||||
10
vendor/golang.org/x/sys/unix/syscall.go
generated
vendored
10
vendor/golang.org/x/sys/unix/syscall.go
generated
vendored
@@ -29,8 +29,6 @@ import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/internal/unsafeheader"
|
||||
)
|
||||
|
||||
// ByteSliceFromString returns a NUL-terminated slice of bytes
|
||||
@@ -82,13 +80,7 @@ func BytePtrToString(p *byte) string {
|
||||
ptr = unsafe.Pointer(uintptr(ptr) + 1)
|
||||
}
|
||||
|
||||
var s []byte
|
||||
h := (*unsafeheader.Slice)(unsafe.Pointer(&s))
|
||||
h.Data = unsafe.Pointer(p)
|
||||
h.Len = n
|
||||
h.Cap = n
|
||||
|
||||
return string(s)
|
||||
return string(unsafe.Slice(p, n))
|
||||
}
|
||||
|
||||
// Single-word zero for use when we need a valid pointer to 0 bytes.
|
||||
|
||||
2
vendor/golang.org/x/sys/unix/syscall_aix.go
generated
vendored
2
vendor/golang.org/x/sys/unix/syscall_aix.go
generated
vendored
@@ -253,7 +253,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
|
||||
var empty bool
|
||||
if len(oob) > 0 {
|
||||
// send at least one normal byte
|
||||
empty := emptyIovecs(iov)
|
||||
empty = emptyIovecs(iov)
|
||||
if empty {
|
||||
var iova [1]Iovec
|
||||
iova[0].Base = &dummy
|
||||
|
||||
2
vendor/golang.org/x/sys/unix/syscall_bsd.go
generated
vendored
2
vendor/golang.org/x/sys/unix/syscall_bsd.go
generated
vendored
@@ -363,7 +363,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
|
||||
var empty bool
|
||||
if len(oob) > 0 {
|
||||
// send at least one normal byte
|
||||
empty := emptyIovecs(iov)
|
||||
empty = emptyIovecs(iov)
|
||||
if empty {
|
||||
var iova [1]Iovec
|
||||
iova[0].Base = &dummy
|
||||
|
||||
32
vendor/golang.org/x/sys/unix/syscall_darwin.1_12.go
generated
vendored
32
vendor/golang.org/x/sys/unix/syscall_darwin.1_12.go
generated
vendored
@@ -1,32 +0,0 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build darwin && go1.12 && !go1.13
|
||||
// +build darwin,go1.12,!go1.13
|
||||
|
||||
package unix
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const _SYS_GETDIRENTRIES64 = 344
|
||||
|
||||
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
|
||||
// To implement this using libSystem we'd need syscall_syscallPtr for
|
||||
// fdopendir. However, syscallPtr was only added in Go 1.13, so we fall
|
||||
// back to raw syscalls for this func on Go 1.12.
|
||||
var p unsafe.Pointer
|
||||
if len(buf) > 0 {
|
||||
p = unsafe.Pointer(&buf[0])
|
||||
} else {
|
||||
p = unsafe.Pointer(&_zero)
|
||||
}
|
||||
r0, _, e1 := Syscall6(_SYS_GETDIRENTRIES64, uintptr(fd), uintptr(p), uintptr(len(buf)), uintptr(unsafe.Pointer(basep)), 0, 0)
|
||||
n = int(r0)
|
||||
if e1 != 0 {
|
||||
return n, errnoErr(e1)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
108
vendor/golang.org/x/sys/unix/syscall_darwin.1_13.go
generated
vendored
108
vendor/golang.org/x/sys/unix/syscall_darwin.1_13.go
generated
vendored
@@ -1,108 +0,0 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build darwin && go1.13
|
||||
// +build darwin,go1.13
|
||||
|
||||
package unix
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/internal/unsafeheader"
|
||||
)
|
||||
|
||||
//sys closedir(dir uintptr) (err error)
|
||||
//sys readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno)
|
||||
|
||||
func fdopendir(fd int) (dir uintptr, err error) {
|
||||
r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0)
|
||||
dir = uintptr(r0)
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var libc_fdopendir_trampoline_addr uintptr
|
||||
|
||||
//go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib"
|
||||
|
||||
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
|
||||
// Simulate Getdirentries using fdopendir/readdir_r/closedir.
|
||||
// We store the number of entries to skip in the seek
|
||||
// offset of fd. See issue #31368.
|
||||
// It's not the full required semantics, but should handle the case
|
||||
// of calling Getdirentries or ReadDirent repeatedly.
|
||||
// It won't handle assigning the results of lseek to *basep, or handle
|
||||
// the directory being edited underfoot.
|
||||
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// We need to duplicate the incoming file descriptor
|
||||
// because the caller expects to retain control of it, but
|
||||
// fdopendir expects to take control of its argument.
|
||||
// Just Dup'ing the file descriptor is not enough, as the
|
||||
// result shares underlying state. Use Openat to make a really
|
||||
// new file descriptor referring to the same directory.
|
||||
fd2, err := Openat(fd, ".", O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
d, err := fdopendir(fd2)
|
||||
if err != nil {
|
||||
Close(fd2)
|
||||
return 0, err
|
||||
}
|
||||
defer closedir(d)
|
||||
|
||||
var cnt int64
|
||||
for {
|
||||
var entry Dirent
|
||||
var entryp *Dirent
|
||||
e := readdir_r(d, &entry, &entryp)
|
||||
if e != 0 {
|
||||
return n, errnoErr(e)
|
||||
}
|
||||
if entryp == nil {
|
||||
break
|
||||
}
|
||||
if skip > 0 {
|
||||
skip--
|
||||
cnt++
|
||||
continue
|
||||
}
|
||||
|
||||
reclen := int(entry.Reclen)
|
||||
if reclen > len(buf) {
|
||||
// Not enough room. Return for now.
|
||||
// The counter will let us know where we should start up again.
|
||||
// Note: this strategy for suspending in the middle and
|
||||
// restarting is O(n^2) in the length of the directory. Oh well.
|
||||
break
|
||||
}
|
||||
|
||||
// Copy entry into return buffer.
|
||||
var s []byte
|
||||
hdr := (*unsafeheader.Slice)(unsafe.Pointer(&s))
|
||||
hdr.Data = unsafe.Pointer(&entry)
|
||||
hdr.Cap = reclen
|
||||
hdr.Len = reclen
|
||||
copy(buf, s)
|
||||
|
||||
buf = buf[reclen:]
|
||||
n += reclen
|
||||
cnt++
|
||||
}
|
||||
// Set the seek offset of the input fd to record
|
||||
// how many files we've already returned.
|
||||
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
91
vendor/golang.org/x/sys/unix/syscall_darwin.go
generated
vendored
91
vendor/golang.org/x/sys/unix/syscall_darwin.go
generated
vendored
@@ -19,6 +19,96 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
//sys closedir(dir uintptr) (err error)
|
||||
//sys readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno)
|
||||
|
||||
func fdopendir(fd int) (dir uintptr, err error) {
|
||||
r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0)
|
||||
dir = uintptr(r0)
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var libc_fdopendir_trampoline_addr uintptr
|
||||
|
||||
//go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib"
|
||||
|
||||
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
|
||||
// Simulate Getdirentries using fdopendir/readdir_r/closedir.
|
||||
// We store the number of entries to skip in the seek
|
||||
// offset of fd. See issue #31368.
|
||||
// It's not the full required semantics, but should handle the case
|
||||
// of calling Getdirentries or ReadDirent repeatedly.
|
||||
// It won't handle assigning the results of lseek to *basep, or handle
|
||||
// the directory being edited underfoot.
|
||||
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// We need to duplicate the incoming file descriptor
|
||||
// because the caller expects to retain control of it, but
|
||||
// fdopendir expects to take control of its argument.
|
||||
// Just Dup'ing the file descriptor is not enough, as the
|
||||
// result shares underlying state. Use Openat to make a really
|
||||
// new file descriptor referring to the same directory.
|
||||
fd2, err := Openat(fd, ".", O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
d, err := fdopendir(fd2)
|
||||
if err != nil {
|
||||
Close(fd2)
|
||||
return 0, err
|
||||
}
|
||||
defer closedir(d)
|
||||
|
||||
var cnt int64
|
||||
for {
|
||||
var entry Dirent
|
||||
var entryp *Dirent
|
||||
e := readdir_r(d, &entry, &entryp)
|
||||
if e != 0 {
|
||||
return n, errnoErr(e)
|
||||
}
|
||||
if entryp == nil {
|
||||
break
|
||||
}
|
||||
if skip > 0 {
|
||||
skip--
|
||||
cnt++
|
||||
continue
|
||||
}
|
||||
|
||||
reclen := int(entry.Reclen)
|
||||
if reclen > len(buf) {
|
||||
// Not enough room. Return for now.
|
||||
// The counter will let us know where we should start up again.
|
||||
// Note: this strategy for suspending in the middle and
|
||||
// restarting is O(n^2) in the length of the directory. Oh well.
|
||||
break
|
||||
}
|
||||
|
||||
// Copy entry into return buffer.
|
||||
s := unsafe.Slice((*byte)(unsafe.Pointer(&entry)), reclen)
|
||||
copy(buf, s)
|
||||
|
||||
buf = buf[reclen:]
|
||||
n += reclen
|
||||
cnt++
|
||||
}
|
||||
// Set the seek offset of the input fd to record
|
||||
// how many files we've already returned.
|
||||
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// SockaddrDatalink implements the Sockaddr interface for AF_LINK type sockets.
|
||||
type SockaddrDatalink struct {
|
||||
Len uint8
|
||||
@@ -140,6 +230,7 @@ func direntNamlen(buf []byte) (uint64, bool) {
|
||||
|
||||
func PtraceAttach(pid int) (err error) { return ptrace(PT_ATTACH, pid, 0, 0) }
|
||||
func PtraceDetach(pid int) (err error) { return ptrace(PT_DETACH, pid, 0, 0) }
|
||||
func PtraceDenyAttach() (err error) { return ptrace(PT_DENY_ATTACH, 0, 0, 0) }
|
||||
|
||||
//sysnb pipe(p *[2]int32) (err error)
|
||||
|
||||
|
||||
1
vendor/golang.org/x/sys/unix/syscall_dragonfly.go
generated
vendored
1
vendor/golang.org/x/sys/unix/syscall_dragonfly.go
generated
vendored
@@ -255,6 +255,7 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
|
||||
//sys Chmod(path string, mode uint32) (err error)
|
||||
//sys Chown(path string, uid int, gid int) (err error)
|
||||
//sys Chroot(path string) (err error)
|
||||
//sys ClockGettime(clockid int32, time *Timespec) (err error)
|
||||
//sys Close(fd int) (err error)
|
||||
//sys Dup(fd int) (nfd int, err error)
|
||||
//sys Dup2(from int, to int) (err error)
|
||||
|
||||
1
vendor/golang.org/x/sys/unix/syscall_freebsd.go
generated
vendored
1
vendor/golang.org/x/sys/unix/syscall_freebsd.go
generated
vendored
@@ -319,6 +319,7 @@ func PtraceSingleStep(pid int) (err error) {
|
||||
//sys Chmod(path string, mode uint32) (err error)
|
||||
//sys Chown(path string, uid int, gid int) (err error)
|
||||
//sys Chroot(path string) (err error)
|
||||
//sys ClockGettime(clockid int32, time *Timespec) (err error)
|
||||
//sys Close(fd int) (err error)
|
||||
//sys Dup(fd int) (nfd int, err error)
|
||||
//sys Dup2(from int, to int) (err error)
|
||||
|
||||
9
vendor/golang.org/x/sys/unix/syscall_freebsd_386.go
generated
vendored
9
vendor/golang.org/x/sys/unix/syscall_freebsd_386.go
generated
vendored
@@ -60,8 +60,13 @@ func PtraceGetFsBase(pid int, fsbase *int64) (err error) {
|
||||
return ptrace(PT_GETFSBASE, pid, uintptr(unsafe.Pointer(fsbase)), 0)
|
||||
}
|
||||
|
||||
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint32(countin)}
|
||||
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{
|
||||
Op: int32(req),
|
||||
Offs: offs,
|
||||
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
|
||||
Len: uint32(countin),
|
||||
}
|
||||
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
|
||||
return int(ioDesc.Len), err
|
||||
}
|
||||
|
||||
9
vendor/golang.org/x/sys/unix/syscall_freebsd_amd64.go
generated
vendored
9
vendor/golang.org/x/sys/unix/syscall_freebsd_amd64.go
generated
vendored
@@ -60,8 +60,13 @@ func PtraceGetFsBase(pid int, fsbase *int64) (err error) {
|
||||
return ptrace(PT_GETFSBASE, pid, uintptr(unsafe.Pointer(fsbase)), 0)
|
||||
}
|
||||
|
||||
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
|
||||
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{
|
||||
Op: int32(req),
|
||||
Offs: offs,
|
||||
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
|
||||
Len: uint64(countin),
|
||||
}
|
||||
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
|
||||
return int(ioDesc.Len), err
|
||||
}
|
||||
|
||||
9
vendor/golang.org/x/sys/unix/syscall_freebsd_arm.go
generated
vendored
9
vendor/golang.org/x/sys/unix/syscall_freebsd_arm.go
generated
vendored
@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
|
||||
|
||||
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
|
||||
|
||||
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint32(countin)}
|
||||
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{
|
||||
Op: int32(req),
|
||||
Offs: offs,
|
||||
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
|
||||
Len: uint32(countin),
|
||||
}
|
||||
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
|
||||
return int(ioDesc.Len), err
|
||||
}
|
||||
|
||||
9
vendor/golang.org/x/sys/unix/syscall_freebsd_arm64.go
generated
vendored
9
vendor/golang.org/x/sys/unix/syscall_freebsd_arm64.go
generated
vendored
@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
|
||||
|
||||
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
|
||||
|
||||
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
|
||||
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{
|
||||
Op: int32(req),
|
||||
Offs: offs,
|
||||
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
|
||||
Len: uint64(countin),
|
||||
}
|
||||
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
|
||||
return int(ioDesc.Len), err
|
||||
}
|
||||
|
||||
9
vendor/golang.org/x/sys/unix/syscall_freebsd_riscv64.go
generated
vendored
9
vendor/golang.org/x/sys/unix/syscall_freebsd_riscv64.go
generated
vendored
@@ -56,8 +56,13 @@ func sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
|
||||
|
||||
func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno)
|
||||
|
||||
func PtraceIO(req int, pid int, addr uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{Op: int32(req), Offs: (*byte)(unsafe.Pointer(addr)), Addr: (*byte)(unsafe.Pointer(&out[0])), Len: uint64(countin)}
|
||||
func PtraceIO(req int, pid int, offs uintptr, out []byte, countin int) (count int, err error) {
|
||||
ioDesc := PtraceIoDesc{
|
||||
Op: int32(req),
|
||||
Offs: offs,
|
||||
Addr: uintptr(unsafe.Pointer(&out[0])), // TODO(#58351): this is not safe.
|
||||
Len: uint64(countin),
|
||||
}
|
||||
err = ptrace(PT_IO, pid, uintptr(unsafe.Pointer(&ioDesc)), 0)
|
||||
return int(ioDesc.Len), err
|
||||
}
|
||||
|
||||
22
vendor/golang.org/x/sys/unix/syscall_hurd.go
generated
vendored
Normal file
22
vendor/golang.org/x/sys/unix/syscall_hurd.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build hurd
|
||||
// +build hurd
|
||||
|
||||
package unix
|
||||
|
||||
/*
|
||||
#include <stdint.h>
|
||||
int ioctl(int, unsigned long int, uintptr_t);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
func ioctl(fd int, req uint, arg uintptr) (err error) {
|
||||
r0, er := C.ioctl(C.int(fd), C.ulong(req), C.uintptr_t(arg))
|
||||
if r0 == -1 && er != nil {
|
||||
err = er
|
||||
}
|
||||
return
|
||||
}
|
||||
29
vendor/golang.org/x/sys/unix/syscall_hurd_386.go
generated
vendored
Normal file
29
vendor/golang.org/x/sys/unix/syscall_hurd_386.go
generated
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build 386 && hurd
|
||||
// +build 386,hurd
|
||||
|
||||
package unix
|
||||
|
||||
const (
|
||||
TIOCGETA = 0x62251713
|
||||
)
|
||||
|
||||
type Winsize struct {
|
||||
Row uint16
|
||||
Col uint16
|
||||
Xpixel uint16
|
||||
Ypixel uint16
|
||||
}
|
||||
|
||||
type Termios struct {
|
||||
Iflag uint32
|
||||
Oflag uint32
|
||||
Cflag uint32
|
||||
Lflag uint32
|
||||
Cc [20]uint8
|
||||
Ispeed int32
|
||||
Ospeed int32
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user