Implement File based configuration (#351)

* Implement file-based configuration

* Implement file-based configuration

* Replace DefaultServerCapabilities with NewDefaultServerCapabilities() to avoid data race (#360)

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>

* Only pass a copy of system.Info to hooks (#365)

* Only pass a copy of system.Info to hooks

* Rename Itoa to Int64toa

---------

Co-authored-by: JB <28275108+mochi-co@users.noreply.github.com>

* Allow configurable max stored qos > 0 messages (#359)

* Allow configurable max stored qos > 0 messages

* Only rollback Inflight if QoS > 0

* Only rollback Inflight if QoS > 0

* Minor refactor

* Update server version

* Implement file-based configuration

* Implement file-based configuration

* update configs with maximum_inflight value

* update docker configuration

* fix tests

---------

Co-authored-by: mochi-co <moumochi@icloud.com>
Co-authored-by: thedevop <60499013+thedevop@users.noreply.github.com>
This commit is contained in:
JB
2024-03-18 03:28:12 +00:00
committed by GitHub
parent 26720c2f6e
commit 26418c6fd8
44 changed files with 1160 additions and 219 deletions

View File

@@ -11,21 +11,12 @@ RUN go mod download
COPY . ./ COPY . ./
RUN go build -o /app/mochi ./cmd RUN go build -o /app/mochi ./cmd/docker
FROM alpine FROM alpine
WORKDIR / WORKDIR /
COPY --from=builder /app/mochi . COPY --from=builder /app/mochi .
# tcp
EXPOSE 1883
# websockets
EXPOSE 1882
# dashboard
EXPOSE 8080
ENTRYPOINT [ "/mochi" ] ENTRYPOINT [ "/mochi" ]
CMD ["/cmd/docker", "--config", "config.yaml"]

View File

@@ -60,7 +60,6 @@ Unless it's a critical issue, new releases typically go out over the weekend.
- Please [open an issue](https://github.com/mochi-mqtt/server/issues) to request new features or event hooks! - Please [open an issue](https://github.com/mochi-mqtt/server/issues) to request new features or event hooks!
- Cluster support. - Cluster support.
- Enhanced Metrics support. - Enhanced Metrics support.
- File-based server configuration (supporting docker).
## Quick Start ## Quick Start
### Running the Broker with Go ### Running the Broker with Go
@@ -77,18 +76,50 @@ You can now pull and run the [official Mochi MQTT image](https://hub.docker.com/
```sh ```sh
docker pull mochimqtt/server docker pull mochimqtt/server
or or
docker run mochimqtt/server docker run -v $(pwd)/config.yaml:/config.yaml mochimqtt/server
``` ```
This is a work in progress, and a [file-based configuration](https://github.com/orgs/mochi-mqtt/projects/2) is being developed to better support this implementation. _More substantial docker support is being discussed [here](https://github.com/orgs/mochi-mqtt/discussions/281#discussion-5544545) and [here](https://github.com/orgs/mochi-mqtt/discussions/209). Please join the discussion if you use Mochi-MQTT in this environment._ For most use cases, you can use File Based Configuration to configure the server, by specifying a valid `yaml` or `json` config file.
A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Websocket, TCP, and Stats server: A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Websocket, TCP, and Stats server, using the `allow-all` auth hook.
```sh ```sh
docker build -t mochi:latest . docker build -t mochi:latest .
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 -v $(pwd)/config.yaml:/config.yaml mochi:latest
``` ```
### File Based Configuration
You can use File Based Configuration with either the Docker image (described above), or by running the build binary with the `--config=config.yaml` or `--config=config.json` parameter.
Configuration files provide a convenient mechanism for easily preparing a server with the most common configurations. You can enable and configure built-in hooks and listeners, and specify server options and compatibilities:
```yaml
listeners:
- type: "tcp"
id: "tcp12"
address: ":1883"
- type: "ws"
id: "ws1"
address: ":1882"
- type: "sysinfo"
id: "stats"
address: ":1880"
hooks:
auth:
allow_all: true
options:
inline_client: true
```
Please review the examples found in `examples/config` for all available configuration options.
There are a few conditions to note:
1. If you use file-based configuration, you can only have one of each hook type.
2. You can only use built in hooks with file-based configuration, as the type and configuration structure needs to be known by the server in order for it to be applied.
3. You can only use built in listeners, for the reasons above.
If you need to implement custom hooks or listeners, please do so using the traditional manner indicated in `cmd/main.go`.
## Developing with Mochi MQTT ## Developing with Mochi MQTT
### Importing as a package ### Importing as a package
Importing Mochi MQTT as a package requires just a few lines of code to get started. Importing Mochi MQTT as a package requires just a few lines of code to get started.

66
cmd/docker/main.go Normal file
View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt
// SPDX-FileContributor: dgduncan, mochi-co
package main
import (
"flag"
"fmt"
"github.com/mochi-mqtt/server/v2/config"
"log"
"log/slog"
"os"
"os/signal"
"syscall"
mqtt "github.com/mochi-mqtt/server/v2"
)
func main() {
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, nil))) // set basic logger to ensure logs before configuration are in a consistent format
configFile := flag.String("config", "config.yaml", "path to mochi config yaml or json file")
flag.Parse()
entries, err := os.ReadDir("./")
if err != nil {
log.Fatal(err)
}
for _, e := range entries {
fmt.Println(e.Name())
}
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
configBytes, err := os.ReadFile(*configFile)
if err != nil {
log.Fatal(err)
}
options, err := config.FromBytes(configBytes)
if err != nil {
log.Fatal(err)
}
server := mqtt.New(options)
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("mochi mqtt shutdown complete")
}

View File

@@ -33,19 +33,31 @@ func main() {
server := mqtt.New(nil) server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: *tcpAddr,
})
err := server.AddListener(tcp) err := server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
ws := listeners.NewWebsocket("ws1", *wsAddr, nil) ws := listeners.NewWebsocket(listeners.Config{
ID: "ws1",
Address: *wsAddr,
})
err = server.AddListener(ws) err = server.AddListener(ws)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
stats := listeners.NewHTTPStats("stats", *infoAddr, nil, server.Info) stats := listeners.NewHTTPStats(
listeners.Config{
ID: "info",
Address: *infoAddr,
},
server.Info,
)
err = server.AddListener(stats) err = server.AddListener(stats)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@@ -61,6 +73,5 @@ func main() {
<-done <-done
server.Log.Warn("caught signal, stopping...") server.Log.Warn("caught signal, stopping...")
_ = server.Close() _ = server.Close()
server.Log.Info("main.go finished") server.Log.Info("mochi mqtt shutdown complete")
} }

15
config.yaml Normal file
View File

@@ -0,0 +1,15 @@
listeners:
- type: "tcp"
id: "tcp12"
address: ":1883"
- type: "ws"
id: "ws1"
address: ":1882"
- type: "sysinfo"
id: "stats"
address: ":1880"
hooks:
auth:
allow_all: true
options:
inline_client: true

144
config/config.go Normal file
View File

@@ -0,0 +1,144 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package config
import (
"encoding/json"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/debug"
"github.com/mochi-mqtt/server/v2/hooks/storage/badger"
"github.com/mochi-mqtt/server/v2/hooks/storage/bolt"
"github.com/mochi-mqtt/server/v2/hooks/storage/redis"
"github.com/mochi-mqtt/server/v2/listeners"
"gopkg.in/yaml.v3"
mqtt "github.com/mochi-mqtt/server/v2"
)
// config defines the structure of configuration data to be parsed from a config source.
type config struct {
Options mqtt.Options
Listeners []listeners.Config `yaml:"listeners" json:"listeners"`
HookConfigs HookConfigs `yaml:"hooks" json:"hooks"`
}
// HookConfigs contains configurations to enable individual hooks.
type HookConfigs struct {
Auth *HookAuthConfig `yaml:"auth" json:"auth"`
Storage *HookStorageConfig `yaml:"storage" json:"storage"`
Debug *debug.Options `yaml:"debug" json:"debug"`
}
// HookAuthConfig contains configurations for the auth hook.
type HookAuthConfig struct {
Ledger auth.Ledger `yaml:"ledger" json:"ledger"`
AllowAll bool `yaml:"allow_all" json:"allow_all"`
}
// HookStorageConfig contains configurations for the different storage hooks.
type HookStorageConfig struct {
Badger *badger.Options `yaml:"badger" json:"badger"`
Bolt *bolt.Options `yaml:"bolt" json:"bolt"`
Redis *redis.Options `yaml:"redis" json:"redis"`
}
// ToHooks converts Hook file configurations into Hooks to be added to the server.
func (hc HookConfigs) ToHooks() []mqtt.HookLoadConfig {
var hlc []mqtt.HookLoadConfig
if hc.Auth != nil {
hlc = append(hlc, hc.toHooksAuth()...)
}
if hc.Storage != nil {
hlc = append(hlc, hc.toHooksStorage()...)
}
if hc.Debug != nil {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(debug.Hook),
Config: hc.Debug,
})
}
return hlc
}
// toHooksAuth converts auth hook configurations into auth hooks.
func (hc HookConfigs) toHooksAuth() []mqtt.HookLoadConfig {
var hlc []mqtt.HookLoadConfig
if hc.Auth.AllowAll {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(auth.AllowHook),
})
} else {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(auth.Hook),
Config: &auth.Options{
Ledger: &auth.Ledger{ // avoid copying sync.Locker
Users: hc.Auth.Ledger.Users,
Auth: hc.Auth.Ledger.Auth,
ACL: hc.Auth.Ledger.ACL,
},
},
})
}
return hlc
}
// toHooksAuth converts storage hook configurations into storage hooks.
func (hc HookConfigs) toHooksStorage() []mqtt.HookLoadConfig {
var hlc []mqtt.HookLoadConfig
if hc.Storage.Badger != nil {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(badger.Hook),
Config: hc.Storage.Badger,
})
}
if hc.Storage.Bolt != nil {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(bolt.Hook),
Config: hc.Storage.Bolt,
})
}
if hc.Storage.Redis != nil {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(redis.Hook),
Config: hc.Storage.Redis,
})
}
return hlc
}
// FromBytes unmarshals a byte slice of JSON or YAML config data into a valid server options value.
// Any hooks configurations are converted into Hooks using the toHooks methods in this package.
func FromBytes(b []byte) (*mqtt.Options, error) {
c := new(config)
o := mqtt.Options{}
if len(b) == 0 {
return nil, nil
}
if b[0] == '{' {
err := json.Unmarshal(b, c)
if err != nil {
return nil, err
}
} else {
err := yaml.Unmarshal(b, c)
if err != nil {
return nil, err
}
}
o = c.Options
o.Hooks = c.HookConfigs.ToHooks()
o.Listeners = c.Listeners
return &o, nil
}

212
config/config_test.go Normal file
View File

@@ -0,0 +1,212 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package config
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/storage/badger"
"github.com/mochi-mqtt/server/v2/hooks/storage/bolt"
"github.com/mochi-mqtt/server/v2/hooks/storage/redis"
"github.com/mochi-mqtt/server/v2/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
)
var (
yamlBytes = []byte(`
listeners:
- type: "tcp"
id: "file-tcp1"
address: ":1883"
hooks:
auth:
allow_all: true
options:
client_net_write_buffer_size: 2048
capabilities:
minimum_protocol_version: 3
compatibilities:
restore_sys_info_on_restart: true
`)
jsonBytes = []byte(`{
"listeners": [
{
"type": "tcp",
"id": "file-tcp1",
"address": ":1883"
}
],
"hooks": {
"auth": {
"allow_all": true
}
},
"options": {
"client_net_write_buffer_size": 2048,
"capabilities": {
"minimum_protocol_version": 3,
"compatibilities": {
"restore_sys_info_on_restart": true
}
}
}
}
`)
parsedOptions = mqtt.Options{
Listeners: []listeners.Config{
{
Type: listeners.TypeTCP,
ID: "file-tcp1",
Address: ":1883",
},
},
Hooks: []mqtt.HookLoadConfig{
{
Hook: new(auth.AllowHook),
},
},
ClientNetWriteBufferSize: 2048,
Capabilities: &mqtt.Capabilities{
MinimumProtocolVersion: 3,
Compatibilities: mqtt.Compatibilities{
RestoreSysInfoOnRestart: true,
},
},
}
)
func TestFromBytesEmptyL(t *testing.T) {
_, err := FromBytes([]byte{})
require.NoError(t, err)
}
func TestFromBytesYAML(t *testing.T) {
o, err := FromBytes(yamlBytes)
require.NoError(t, err)
require.Equal(t, parsedOptions, *o)
}
func TestFromBytesYAMLError(t *testing.T) {
_, err := FromBytes(append(yamlBytes, 'a'))
require.Error(t, err)
}
func TestFromBytesJSON(t *testing.T) {
o, err := FromBytes(jsonBytes)
require.NoError(t, err)
require.Equal(t, parsedOptions, *o)
}
func TestFromBytesJSONError(t *testing.T) {
_, err := FromBytes(append(jsonBytes, 'a'))
require.Error(t, err)
}
func TestToHooksAuthAllowAll(t *testing.T) {
hc := HookConfigs{
Auth: &HookAuthConfig{
AllowAll: true,
},
}
th := hc.toHooksAuth()
expect := []mqtt.HookLoadConfig{
{Hook: new(auth.AllowHook)},
}
require.Equal(t, expect, th)
}
func TestToHooksAuthAllowLedger(t *testing.T) {
hc := HookConfigs{
Auth: &HookAuthConfig{
Ledger: auth.Ledger{
Auth: auth.AuthRules{
{Username: "peach", Password: "password1", Allow: true},
},
},
},
}
th := hc.toHooksAuth()
expect := []mqtt.HookLoadConfig{
{
Hook: new(auth.Hook),
Config: &auth.Options{
Ledger: &auth.Ledger{ // avoid copying sync.Locker
Auth: auth.AuthRules{
{Username: "peach", Password: "password1", Allow: true},
},
},
},
},
}
require.Equal(t, expect, th)
}
func TestToHooksStorageBadger(t *testing.T) {
hc := HookConfigs{
Storage: &HookStorageConfig{
Badger: &badger.Options{
Path: "badger",
},
},
}
th := hc.toHooksStorage()
expect := []mqtt.HookLoadConfig{
{
Hook: new(badger.Hook),
Config: hc.Storage.Badger,
},
}
require.Equal(t, expect, th)
}
func TestToHooksStorageBolt(t *testing.T) {
hc := HookConfigs{
Storage: &HookStorageConfig{
Bolt: &bolt.Options{
Path: "bolt",
},
},
}
th := hc.toHooksStorage()
expect := []mqtt.HookLoadConfig{
{
Hook: new(bolt.Hook),
Config: hc.Storage.Bolt,
},
}
require.Equal(t, expect, th)
}
func TestToHooksStorageRedis(t *testing.T) {
hc := HookConfigs{
Storage: &HookStorageConfig{
Redis: &redis.Options{
Username: "test",
},
},
}
th := hc.toHooksStorage()
expect := []mqtt.HookLoadConfig{
{
Hook: new(redis.Hook),
Config: hc.Storage.Redis,
},
}
require.Equal(t, expect, th)
}

View File

@@ -63,7 +63,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -45,7 +45,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -32,7 +32,10 @@ func main() {
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: *tcpAddr,
})
err := server.AddListener(tcp) err := server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -0,0 +1,92 @@
{
"listeners": [
{
"type": "tcp",
"id": "file-tcp1",
"address": ":1883"
},
{
"type": "ws",
"id": "file-websocket",
"address": ":1882"
},
{
"type": "healthcheck",
"id": "file-healthcheck",
"address": ":1880"
}
],
"hooks": {
"debug": {
"enable": true
},
"storage": {
"badger": {
"path": "badger.db",
"gc_interval": 3,
"gc_discard_ratio": 0.5
},
"bolt": {
"path": "bolt.db"
},
"redis": {
"h_prefix": "mc",
"username": "mochi",
"password": "melon",
"address": "localhost:6379",
"database": 1
}
},
"auth": {
"allow_all": false,
"ledger": {
"auth": [
{
"username": "peach",
"password": "password1",
"allow": true
}
],
"acl": [
{
"remote": "127.0.0.1:*"
},
{
"username": "melon",
"filters": null,
"melon/#": 3,
"updates/#": 2
}
]
}
}
},
"options": {
"client_net_write_buffer_size": 2048,
"client_net_read_buffer_size": 2048,
"sys_topic_resend_interval": 10,
"inline_client": true,
"capabilities": {
"maximum_message_expiry_interval": 100,
"maximum_client_writes_pending": 8192,
"maximum_session_expiry_interval": 86400,
"maximum_packet_size": 0,
"receive_maximum": 1024,
"maximum_inflight": 8192,
"topic_alias_maximum": 65535,
"shared_sub_available": 1,
"minimum_protocol_version": 3,
"maximum_qos": 2,
"retain_available": 1,
"wildcard_sub_available": 1,
"sub_id_available": 1,
"compatibilities": {
"obscure_not_authorized": true,
"passive_client_disconnect": false,
"always_return_response_info": false,
"restore_sys_info_on_restart": false,
"no_inherited_properties_on_ack": false
}
}
}
}

View File

@@ -0,0 +1,64 @@
listeners:
- type: "tcp"
id: "file-tcp1"
address: ":1883"
- type: "ws"
id: "file-websocket"
address: ":1882"
- type: "healthcheck"
id: "file-healthcheck"
address: ":1880"
hooks:
debug:
enable: true
storage:
badger:
path: badger.db
gc_interval: 3
gc_discard_ratio: 0.5
bolt:
path: bolt.db
redis:
h_prefix: "mc"
username: "mochi"
password: "melon"
address: "localhost:6379"
database: 1
auth:
allow_all: false
ledger:
auth:
- username: peach
password: password1
allow: true
acl:
- remote: 127.0.0.1:*
- username: melon
filters:
melon/#: 3
updates/#: 2
options:
client_net_write_buffer_size: 2048
client_net_read_buffer_size: 2048
sys_topic_resend_interval: 10
inline_client: true
capabilities:
maximum_message_expiry_interval: 100
maximum_client_writes_pending: 8192
maximum_session_expiry_interval: 86400
maximum_packet_size: 0
receive_maximum: 1024
maximum_inflight: 8192
topic_alias_maximum: 65535
shared_sub_available: 1
minimum_protocol_version: 3
maximum_qos: 2
retain_available: 1
wildcard_sub_available: 1
sub_id_available: 1
compatibilities:
obscure_not_authorized: true
passive_client_disconnect: false
always_return_response_info: false
restore_sys_info_on_restart: false
no_inherited_properties_on_ack: false

49
examples/config/main.go Normal file
View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"github.com/mochi-mqtt/server/v2/config"
"log"
"os"
"os/signal"
"syscall"
mqtt "github.com/mochi-mqtt/server/v2"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
configBytes, err := os.ReadFile("config.json")
if err != nil {
log.Fatal(err)
}
options, err := config.FromBytes(configBytes)
if err != nil {
log.Fatal(err)
}
server := mqtt.New(options)
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -46,7 +46,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -33,7 +33,10 @@ func main() {
}) })
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err := server.AddListener(tcp) err := server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -31,7 +31,10 @@ func main() {
server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
_ = server.AddHook(new(pahoAuthHook), nil) _ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err := server.AddListener(tcp) err := server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -5,18 +5,16 @@
package main package main
import ( import (
"log"
"os"
"os/signal"
"syscall"
"time"
badgerdb "github.com/dgraph-io/badger" badgerdb "github.com/dgraph-io/badger"
mqtt "github.com/mochi-mqtt/server/v2" mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/hooks/auth" "github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/storage/badger" "github.com/mochi-mqtt/server/v2/hooks/storage/badger"
"github.com/mochi-mqtt/server/v2/listeners" "github.com/mochi-mqtt/server/v2/listeners"
"github.com/timshannon/badgerhold" "github.com/timshannon/badgerhold"
"log"
"os"
"os/signal"
"syscall"
) )
func main() { func main() {
@@ -41,7 +39,7 @@ func main() {
Path: badgerPath, Path: badgerPath,
// Set the interval for garbage collection. Adjust according to your actual scenario. // Set the interval for garbage collection. Adjust according to your actual scenario.
GcInterval: 5 * time.Minute, GcInterval: 5 * 60,
// GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard. // GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard.
// Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value // Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value
@@ -63,7 +61,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -40,7 +40,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -48,7 +48,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -38,7 +38,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -79,7 +79,9 @@ func main() {
server := mqtt.New(nil) server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", &listeners.Config{ tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}) })
err = server.AddListener(tcp) err = server.AddListener(tcp)
@@ -87,7 +89,9 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
ws := listeners.NewWebsocket("ws1", ":1882", &listeners.Config{ ws := listeners.NewWebsocket(listeners.Config{
ID: "ws1",
Address: ":1882",
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}) })
err = server.AddListener(ws) err = server.AddListener(ws)
@@ -95,9 +99,13 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{ stats := listeners.NewHTTPStats(
listeners.Config{
ID: "stats",
Address: ":8080",
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}, server.Info) }, server.Info,
)
err = server.AddListener(stats) err = server.AddListener(stats)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -27,7 +27,10 @@ func main() {
server := mqtt.New(nil) server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
ws := listeners.NewWebsocket("ws1", ":1882", nil) ws := listeners.NewWebsocket(listeners.Config{
ID: "ws1",
Address: ":1882",
})
err := server.AddListener(ws) err := server.AddListener(ws)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -62,6 +62,12 @@ var (
ErrInvalidConfigType = errors.New("invalid config type provided") ErrInvalidConfigType = errors.New("invalid config type provided")
) )
// HookLoadConfig contains the hook and configuration as loaded from a configuration (usually file).
type HookLoadConfig struct {
Hook Hook
Config any
}
// Hook provides an interface of handlers for different events which occur // Hook provides an interface of handlers for different events which occur
// during the lifecycle of the broker. // during the lifecycle of the broker.
type Hook interface { type Hook interface {
@@ -70,6 +76,7 @@ type Hook interface {
Init(config any) error Init(config any) error
Stop() error Stop() error
SetOpts(l *slog.Logger, o *HookOptions) SetOpts(l *slog.Logger, o *HookOptions)
OnStarted() OnStarted()
OnStopped() OnStopped()
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool OnConnectAuthenticate(cl *Client, pk packets.Packet) bool

View File

@@ -16,9 +16,10 @@ import (
// Options contains configuration settings for the debug output. // Options contains configuration settings for the debug output.
type Options struct { type Options struct {
ShowPacketData bool // include decoded packet data (default false) Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config
ShowPings bool // show ping requests and responses (default false) ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false)
ShowPasswords bool // show connecting user passwords (default false) ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false)
ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false)
} }
// Hook is a debugging hook which logs additional low-level information from the server. // Hook is a debugging hook which logs additional low-level information from the server.

View File

@@ -22,7 +22,7 @@ import (
const ( const (
// defaultDbFile is the default file path for the badger db file. // defaultDbFile is the default file path for the badger db file.
defaultDbFile = ".badger" defaultDbFile = ".badger"
defaultGcInterval = 5 * time.Minute defaultGcInterval = 5 * 60 // gc interval in seconds
defaultGcDiscardRatio = 0.5 defaultGcDiscardRatio = 0.5
) )
@@ -54,16 +54,13 @@ func sysInfoKey() string {
// Options contains configuration settings for the BadgerDB instance. // Options contains configuration settings for the BadgerDB instance.
type Options struct { type Options struct {
Options *badgerhold.Options Options *badgerhold.Options
Path string `yaml:"path" json:"path"`
// The interval for garbage collection.
GcInterval time.Duration
// GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard. // GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard.
// Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value // Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value
// would result in more space reclaims at the cost of increased activity on the LSM tree. // would result in more space reclaims at the cost of increased activity on the LSM tree.
// discardRatio must be in the range (0.0, 1.0), both endpoints excluded, otherwise, it will be set to the default value of 0.5. // discardRatio must be in the range (0.0, 1.0), both endpoints excluded, otherwise, it will be set to the default value of 0.5.
GcDiscardRatio float64 GcDiscardRatio float64 `yaml:"gc_discard_ratio" json:"gc_discard_ratio"`
Path string GcInterval int64 `yaml:"gc_interval" json:"gc_interval"`
} }
// Hook is a persistent storage hook based using BadgerDB file store as a backend. // Hook is a persistent storage hook based using BadgerDB file store as a backend.
@@ -151,7 +148,7 @@ func (h *Hook) Init(config any) error {
return err return err
} }
h.gcTicker = time.NewTicker(h.config.GcInterval) h.gcTicker = time.NewTicker(time.Duration(h.config.GcInterval) * time.Second)
go h.GcLoop() go h.GcLoop()
return nil return nil
@@ -224,7 +221,7 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
return return
} }
if cl.StopCause() == packets.ErrSessionTakenOver { if errors.Is(cl.StopCause(), packets.ErrSessionTakenOver) {
return return
} }

View File

@@ -708,7 +708,7 @@ func TestGcLoop(t *testing.T) {
h := new(Hook) h := new(Hook)
h.SetOpts(logger, nil) h.SetOpts(logger, nil)
h.Init(&Options{ h.Init(&Options{
GcInterval: 2 * time.Second, // Set the interval for garbage collection. GcInterval: 2, // Set the interval for garbage collection.
Options: &badgerhold.Options{ Options: &badgerhold.Options{
// BadgerDB options. Modify as needed. // BadgerDB options. Modify as needed.
Options: badgerdb.Options{ Options: badgerdb.Options{

View File

@@ -56,7 +56,7 @@ func sysInfoKey() string {
// Options contains configuration settings for the bolt instance. // Options contains configuration settings for the bolt instance.
type Options struct { type Options struct {
Options *bbolt.Options Options *bbolt.Options
Path string Path string `yaml:"path" json:"path"`
} }
// Hook is a persistent storage hook based using boltdb file store as a backend. // Hook is a persistent storage hook based using boltdb file store as a backend.

View File

@@ -51,7 +51,11 @@ func sysInfoKey() string {
// Options contains configuration settings for the bolt instance. // Options contains configuration settings for the bolt instance.
type Options struct { type Options struct {
HPrefix string Address string `yaml:"address" json:"address"`
Username string `yaml:"username" json:"username"`
Password string `yaml:"password" json:"password"`
Database int `yaml:"database" json:"database"`
HPrefix string `yaml:"h_prefix" json:"h_prefix"`
Options *redis.Options Options *redis.Options
} }
@@ -105,23 +109,31 @@ func (h *Hook) Init(config any) error {
h.ctx = context.Background() h.ctx = context.Background()
if config == nil { if config == nil {
config = &Options{ config = new(Options)
Options: &redis.Options{
Addr: defaultAddr,
},
} }
h.config = config.(*Options)
if h.config.Options == nil {
h.config.Options = &redis.Options{
Addr: defaultAddr,
}
h.config.Options.Addr = h.config.Address
h.config.Options.DB = h.config.Database
h.config.Options.Username = h.config.Username
h.config.Options.Password = h.config.Password
} }
h.config = config.(*Options)
if h.config.HPrefix == "" { if h.config.HPrefix == "" {
h.config.HPrefix = defaultHPrefix h.config.HPrefix = defaultHPrefix
} }
h.Log.Info("connecting to redis service", h.Log.Info(
"connecting to redis service",
"prefix", h.config.HPrefix,
"address", h.config.Options.Addr, "address", h.config.Options.Addr,
"username", h.config.Options.Username, "username", h.config.Options.Username,
"password-len", len(h.config.Options.Password), "password-len", len(h.config.Options.Password),
"db", h.config.Options.DB) "db", h.config.Options.DB,
)
h.db = redis.NewClient(h.config.Options) h.db = redis.NewClient(h.config.Options)
_, err := h.db.Ping(context.Background()).Result() _, err := h.db.Ping(context.Background()).Result()

View File

@@ -135,6 +135,29 @@ func TestInitUseDefaults(t *testing.T) {
require.Equal(t, defaultAddr, h.config.Options.Addr) require.Equal(t, defaultAddr, h.config.Options.Addr)
} }
func TestInitUsePassConfig(t *testing.T) {
s := miniredis.RunT(t)
s.StartAddr(defaultAddr)
defer s.Close()
h := newHook(t, "")
h.SetOpts(logger, nil)
err := h.Init(&Options{
Address: defaultAddr,
Username: "username",
Password: "password",
Database: 2,
})
require.Error(t, err)
h.db.FlushAll(h.ctx)
require.Equal(t, defaultAddr, h.config.Options.Addr)
require.Equal(t, "username", h.config.Options.Username)
require.Equal(t, "password", h.config.Options.Password)
require.Equal(t, 2, h.config.Options.DB)
}
func TestInitBadConfig(t *testing.T) { func TestInitBadConfig(t *testing.T) {
h := new(Hook) h := new(Hook)
h.SetOpts(logger, nil) h.SetOpts(logger, nil)

View File

@@ -13,24 +13,23 @@ import (
"time" "time"
) )
const TypeHealthCheck = "healthcheck"
// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint. // HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint.
type HTTPHealthCheck struct { type HTTPHealthCheck struct {
sync.RWMutex sync.RWMutex
id string // the internal id of the listener id string // the internal id of the listener
address string // the network address to bind to address string // the network address to bind to
config *Config // configuration values for the listener config Config // configuration values for the listener
listen *http.Server // the http server listen *http.Server // the http server
end uint32 // ensure the close methods are only called once end uint32 // ensure the close methods are only called once
} }
// NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address. // NewHTTPHealthCheck initializes and returns a new HTTP listener, listening on an address.
func NewHTTPHealthCheck(id, address string, config *Config) *HTTPHealthCheck { func NewHTTPHealthCheck(config Config) *HTTPHealthCheck {
if config == nil {
config = new(Config)
}
return &HTTPHealthCheck{ return &HTTPHealthCheck{
id: id, id: config.ID,
address: address, address: config.Address,
config: config, config: config,
} }
} }

View File

@@ -14,47 +14,44 @@ import (
) )
func TestNewHTTPHealthCheck(t *testing.T) { func TestNewHTTPHealthCheck(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
require.Equal(t, "healthcheck", l.id) require.Equal(t, basicConfig.ID, l.id)
require.Equal(t, testAddr, l.address) require.Equal(t, basicConfig.Address, l.address)
} }
func TestHTTPHealthCheckID(t *testing.T) { func TestHTTPHealthCheckID(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
require.Equal(t, "healthcheck", l.ID()) require.Equal(t, basicConfig.ID, l.ID())
} }
func TestHTTPHealthCheckAddress(t *testing.T) { func TestHTTPHealthCheckAddress(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
require.Equal(t, testAddr, l.Address()) require.Equal(t, basicConfig.Address, l.Address())
} }
func TestHTTPHealthCheckProtocol(t *testing.T) { func TestHTTPHealthCheckProtocol(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
require.Equal(t, "http", l.Protocol()) require.Equal(t, "http", l.Protocol())
} }
func TestHTTPHealthCheckTLSProtocol(t *testing.T) { func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ l := NewHTTPHealthCheck(tlsConfig)
TLSConfig: tlsConfigBasic,
})
_ = l.Init(logger) _ = l.Init(logger)
require.Equal(t, "https", l.Protocol()) require.Equal(t, "https", l.Protocol())
} }
func TestHTTPHealthCheckInit(t *testing.T) { func TestHTTPHealthCheckInit(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, l.listen) require.NotNil(t, l.listen)
require.Equal(t, testAddr, l.listen.Addr) require.Equal(t, basicConfig.Address, l.listen.Addr)
} }
func TestHTTPHealthCheckServeAndClose(t *testing.T) { func TestHTTPHealthCheckServeAndClose(t *testing.T) {
// setup http stats listener // setup http stats listener
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -90,7 +87,7 @@ func TestHTTPHealthCheckServeAndClose(t *testing.T) {
func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) { func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
// setup http stats listener // setup http stats listener
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -125,10 +122,7 @@ func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
} }
func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) { func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ l := NewHTTPHealthCheck(tlsConfig)
TLSConfig: tlsConfigBasic,
})
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -17,27 +17,26 @@ import (
"github.com/mochi-mqtt/server/v2/system" "github.com/mochi-mqtt/server/v2/system"
) )
const TypeSysInfo = "sysinfo"
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. // HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct { type HTTPStats struct {
sync.RWMutex sync.RWMutex
id string // the internal id of the listener id string // the internal id of the listener
address string // the network address to bind to address string // the network address to bind to
config *Config // configuration values for the listener config Config // configuration values for the listener
listen *http.Server // the http server listen *http.Server // the http server
sysInfo *system.Info // pointers to the server data sysInfo *system.Info // pointers to the server data
log *slog.Logger // server logger log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once end uint32 // ensure the close methods are only called once
} }
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address. // NewHTTPStats initializes and returns a new HTTP listener, listening on an address.
func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats { func NewHTTPStats(config Config, sysInfo *system.Info) *HTTPStats {
if config == nil {
config = new(Config)
}
return &HTTPStats{ return &HTTPStats{
id: id,
address: address,
sysInfo: sysInfo, sysInfo: sysInfo,
id: config.ID,
address: config.Address,
config: config, config: config,
} }
} }

View File

@@ -17,38 +17,35 @@ import (
) )
func TestNewHTTPStats(t *testing.T) { func TestNewHTTPStats(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil) l := NewHTTPStats(basicConfig, nil)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address) require.Equal(t, testAddr, l.address)
} }
func TestHTTPStatsID(t *testing.T) { func TestHTTPStatsID(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil) l := NewHTTPStats(basicConfig, nil)
require.Equal(t, "t1", l.ID()) require.Equal(t, "t1", l.ID())
} }
func TestHTTPStatsAddress(t *testing.T) { func TestHTTPStatsAddress(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil) l := NewHTTPStats(basicConfig, nil)
require.Equal(t, testAddr, l.Address()) require.Equal(t, testAddr, l.Address())
} }
func TestHTTPStatsProtocol(t *testing.T) { func TestHTTPStatsProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil) l := NewHTTPStats(basicConfig, nil)
require.Equal(t, "http", l.Protocol()) require.Equal(t, "http", l.Protocol())
} }
func TestHTTPStatsTLSProtocol(t *testing.T) { func TestHTTPStatsTLSProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, &Config{ l := NewHTTPStats(tlsConfig, nil)
TLSConfig: tlsConfigBasic,
}, nil)
_ = l.Init(logger) _ = l.Init(logger)
require.Equal(t, "https", l.Protocol()) require.Equal(t, "https", l.Protocol())
} }
func TestHTTPStatsInit(t *testing.T) { func TestHTTPStatsInit(t *testing.T) {
sysInfo := new(system.Info) sysInfo := new(system.Info)
l := NewHTTPStats("t1", testAddr, nil, sysInfo) l := NewHTTPStats(basicConfig, sysInfo)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -64,7 +61,7 @@ func TestHTTPStatsServeAndClose(t *testing.T) {
} }
// setup http stats listener // setup http stats listener
l := NewHTTPStats("t1", testAddr, nil, sysInfo) l := NewHTTPStats(basicConfig, sysInfo)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -109,9 +106,7 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) {
Version: "test", Version: "test",
} }
l := NewHTTPStats("t1", testAddr, &Config{ l := NewHTTPStats(tlsConfig, sysInfo)
TLSConfig: tlsConfigBasic,
}, sysInfo)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -132,7 +127,9 @@ func TestHTTPStatsFailedToServe(t *testing.T) {
} }
// setup http stats listener // setup http stats listener
l := NewHTTPStats("t1", "wrong_addr", nil, sysInfo) config := basicConfig
config.Address = "wrong_addr"
l := NewHTTPStats(config, sysInfo)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -14,8 +14,10 @@ import (
// Config contains configuration values for a listener. // Config contains configuration values for a listener.
type Config struct { type Config struct {
// TLSConfig is a tls.Config configuration to be used with the listener. Type string
// See examples folder for basic and mutual-tls use. ID string
Address string
// TLSConfig is a tls.Config configuration to be used with the listener. See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config TLSConfig *tls.Config
} }

View File

@@ -19,6 +19,9 @@ import (
const testAddr = ":22222" const testAddr = ":22222"
var ( var (
basicConfig = Config{ID: "t1", Address: testAddr}
tlsConfig = Config{ID: "t1", Address: testAddr, TLSConfig: tlsConfigBasic}
logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
testCertificate = []byte(`-----BEGIN CERTIFICATE----- testCertificate = []byte(`-----BEGIN CERTIFICATE-----
@@ -65,6 +68,7 @@ func init() {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
} }
tlsConfig.TLSConfig = tlsConfigBasic
} }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {

View File

@@ -12,6 +12,8 @@ import (
"log/slog" "log/slog"
) )
const TypeMock = "mock"
// MockEstablisher is a function signature which can be used in testing. // MockEstablisher is a function signature which can be used in testing.
func MockEstablisher(id string, c net.Conn) error { func MockEstablisher(id string, c net.Conn) error {
return nil return nil

View File

@@ -13,26 +13,24 @@ import (
"log/slog" "log/slog"
) )
const TypeTCP = "tcp"
// TCP is a listener for establishing client connections on basic TCP protocol. // TCP is a listener for establishing client connections on basic TCP protocol.
type TCP struct { // [MQTT-4.2.0-1] type TCP struct { // [MQTT-4.2.0-1]
sync.RWMutex sync.RWMutex
id string // the internal id of the listener id string // the internal id of the listener
address string // the network address to bind to address string // the network address to bind to
listen net.Listener // a net.Listener which will listen for new clients listen net.Listener // a net.Listener which will listen for new clients
config *Config // configuration values for the listener config Config // configuration values for the listener
log *slog.Logger // server logger log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once end uint32 // ensure the close methods are only called once
} }
// NewTCP initialises and returns a new TCP listener, listening on an address. // NewTCP initializes and returns a new TCP listener, listening on an address.
func NewTCP(id, address string, config *Config) *TCP { func NewTCP(config Config) *TCP {
if config == nil {
config = new(Config)
}
return &TCP{ return &TCP{
id: id, id: config.ID,
address: address, address: config.Address,
config: config, config: config,
} }
} }

View File

@@ -14,45 +14,40 @@ import (
) )
func TestNewTCP(t *testing.T) { func TestNewTCP(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address) require.Equal(t, testAddr, l.address)
} }
func TestTCPID(t *testing.T) { func TestTCPID(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
require.Equal(t, "t1", l.ID()) require.Equal(t, "t1", l.ID())
} }
func TestTCPAddress(t *testing.T) { func TestTCPAddress(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
require.Equal(t, testAddr, l.Address()) require.Equal(t, testAddr, l.Address())
} }
func TestTCPProtocol(t *testing.T) { func TestTCPProtocol(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
require.Equal(t, "tcp", l.Protocol()) require.Equal(t, "tcp", l.Protocol())
} }
func TestTCPProtocolTLS(t *testing.T) { func TestTCPProtocolTLS(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{ l := NewTCP(tlsConfig)
TLSConfig: tlsConfigBasic,
})
_ = l.Init(logger) _ = l.Init(logger)
defer l.listen.Close() defer l.listen.Close()
require.Equal(t, "tcp", l.Protocol()) require.Equal(t, "tcp", l.Protocol())
} }
func TestTCPInit(t *testing.T) { func TestTCPInit(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
err := l.Init(logger) err := l.Init(logger)
l.Close(MockCloser) l.Close(MockCloser)
require.NoError(t, err) require.NoError(t, err)
l2 := NewTCP("t2", testAddr, &Config{ l2 := NewTCP(tlsConfig)
TLSConfig: tlsConfigBasic,
})
err = l2.Init(logger) err = l2.Init(logger)
l2.Close(MockCloser) l2.Close(MockCloser)
require.NoError(t, err) require.NoError(t, err)
@@ -60,7 +55,7 @@ func TestTCPInit(t *testing.T) {
} }
func TestTCPServeAndClose(t *testing.T) { func TestTCPServeAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -85,9 +80,7 @@ func TestTCPServeAndClose(t *testing.T) {
} }
func TestTCPServeTLSAndClose(t *testing.T) { func TestTCPServeTLSAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{ l := NewTCP(tlsConfig)
TLSConfig: tlsConfigBasic,
})
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -109,7 +102,7 @@ func TestTCPServeTLSAndClose(t *testing.T) {
} }
func TestTCPEstablishThenEnd(t *testing.T) { func TestTCPEstablishThenEnd(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -13,21 +13,25 @@ import (
"log/slog" "log/slog"
) )
const TypeUnix = "unix"
// UnixSock is a listener for establishing client connections on basic UnixSock protocol. // UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct { type UnixSock struct {
sync.RWMutex sync.RWMutex
id string // the internal id of the listener. id string // the internal id of the listener.
address string // the network address to bind to. address string // the network address to bind to.
config Config // configuration values for the listener
listen net.Listener // a net.Listener which will listen for new clients. listen net.Listener // a net.Listener which will listen for new clients.
log *slog.Logger // server logger log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once. end uint32 // ensure the close methods are only called once.
} }
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address. // NewUnixSock initializes and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock { func NewUnixSock(config Config) *UnixSock {
return &UnixSock{ return &UnixSock{
id: id, id: config.ID,
address: address, address: config.Address,
config: config,
} }
} }

View File

@@ -15,41 +15,47 @@ import (
const testUnixAddr = "mochi.sock" const testUnixAddr = "mochi.sock"
var (
unixConfig = Config{ID: "t1", Address: testUnixAddr}
)
func TestNewUnixSock(t *testing.T) { func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address) require.Equal(t, testUnixAddr, l.address)
} }
func TestUnixSockID(t *testing.T) { func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
require.Equal(t, "t1", l.ID()) require.Equal(t, "t1", l.ID())
} }
func TestUnixSockAddress(t *testing.T) { func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
require.Equal(t, testUnixAddr, l.Address()) require.Equal(t, testUnixAddr, l.Address())
} }
func TestUnixSockProtocol(t *testing.T) { func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
require.Equal(t, "unix", l.Protocol()) require.Equal(t, "unix", l.Protocol())
} }
func TestUnixSockInit(t *testing.T) { func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
err := l.Init(logger) err := l.Init(logger)
l.Close(MockCloser) l.Close(MockCloser)
require.NoError(t, err) require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr) t2Config := unixConfig
t2Config.ID = "t2"
l2 := NewUnixSock(t2Config)
err = l2.Init(logger) err = l2.Init(logger)
l2.Close(MockCloser) l2.Close(MockCloser)
require.NoError(t, err) require.NoError(t, err)
} }
func TestUnixSockServeAndClose(t *testing.T) { func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -74,7 +80,7 @@ func TestUnixSockServeAndClose(t *testing.T) {
} }
func TestUnixSockEstablishThenEnd(t *testing.T) { func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -19,6 +19,8 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
const TypeWS = "ws"
var ( var (
// ErrInvalidMessage indicates that a message payload was not valid. // ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary") ErrInvalidMessage = errors.New("message type not binary")
@@ -29,7 +31,7 @@ type Websocket struct { // [MQTT-4.2.0-1]
sync.RWMutex sync.RWMutex
id string // the internal id of the listener id string // the internal id of the listener
address string // the network address to bind to address string // the network address to bind to
config *Config // configuration values for the listener config Config // configuration values for the listener
listen *http.Server // a http server for serving websocket connections listen *http.Server // a http server for serving websocket connections
log *slog.Logger // server logger log *slog.Logger // server logger
establish EstablishFn // the server's establish connection handler establish EstablishFn // the server's establish connection handler
@@ -37,15 +39,11 @@ type Websocket struct { // [MQTT-4.2.0-1]
end uint32 // ensure the close methods are only called once end uint32 // ensure the close methods are only called once
} }
// NewWebsocket initialises and returns a new Websocket listener, listening on an address. // NewWebsocket initializes and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string, config *Config) *Websocket { func NewWebsocket(config Config) *Websocket {
if config == nil {
config = new(Config)
}
return &Websocket{ return &Websocket{
id: id, id: config.ID,
address: address, address: config.Address,
config: config, config: config,
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
Subprotocols: []string{"mqtt"}, Subprotocols: []string{"mqtt"},

View File

@@ -17,35 +17,33 @@ import (
) )
func TestNewWebsocket(t *testing.T) { func TestNewWebsocket(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address) require.Equal(t, testAddr, l.address)
} }
func TestWebsocketID(t *testing.T) { func TestWebsocketID(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Equal(t, "t1", l.ID()) require.Equal(t, "t1", l.ID())
} }
func TestWebsocketAddress(t *testing.T) { func TestWebsocketAddress(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Equal(t, testAddr, l.Address()) require.Equal(t, testAddr, l.Address())
} }
func TestWebsocketProtocol(t *testing.T) { func TestWebsocketProtocol(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Equal(t, "ws", l.Protocol()) require.Equal(t, "ws", l.Protocol())
} }
func TestWebsocketProtocolTLS(t *testing.T) { func TestWebsocketProtocolTLS(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{ l := NewWebsocket(tlsConfig)
TLSConfig: tlsConfigBasic,
})
require.Equal(t, "wss", l.Protocol()) require.Equal(t, "wss", l.Protocol())
} }
func TestWebsocketInit(t *testing.T) { func TestWebsocketInit(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Nil(t, l.listen) require.Nil(t, l.listen)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -53,7 +51,7 @@ func TestWebsocketInit(t *testing.T) {
} }
func TestWebsocketServeAndClose(t *testing.T) { func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
_ = l.Init(logger) _ = l.Init(logger)
o := make(chan bool) o := make(chan bool)
@@ -74,9 +72,7 @@ func TestWebsocketServeAndClose(t *testing.T) {
} }
func TestWebsocketServeTLSAndClose(t *testing.T) { func TestWebsocketServeTLSAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{ l := NewWebsocket(tlsConfig)
TLSConfig: tlsConfigBasic,
})
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -96,9 +92,9 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
} }
func TestWebsocketFailedToServe(t *testing.T) { func TestWebsocketFailedToServe(t *testing.T) {
l := NewWebsocket("t1", "wrong_addr", &Config{ config := tlsConfig
TLSConfig: tlsConfigBasic, config.Address = "wrong_addr"
}) l := NewWebsocket(config)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -117,7 +113,7 @@ func TestWebsocketFailedToServe(t *testing.T) {
} }
func TestWebsocketUpgrade(t *testing.T) { func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
_ = l.Init(logger) _ = l.Init(logger)
e := make(chan bool) e := make(chan bool)
@@ -136,7 +132,7 @@ func TestWebsocketUpgrade(t *testing.T) {
} }
func TestWebsocketConnectionReads(t *testing.T) { func TestWebsocketConnectionReads(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
_ = l.Init(nil) _ = l.Init(nil)
recv := make(chan []byte) recv := make(chan []byte)

116
server.go
View File

@@ -14,6 +14,7 @@ import (
"runtime" "runtime"
"sort" "sort"
"strconv" "strconv"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -26,7 +27,7 @@ import (
) )
const ( const (
Version = "2.4.6" // the current server version. Version = "2.6.0" // the current server version.
defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes defaultSysTopicInterval int64 = 1 // the interval between $SYS topic publishes
LocalListener = "local" LocalListener = "local"
InlineClientId = "inline" InlineClientId = "inline"
@@ -39,25 +40,26 @@ var (
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists
ErrConnectionClosed = errors.New("connection not open") // connection is closed ErrConnectionClosed = errors.New("connection not open") // connection is closed
ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default
ErrOptionsUnreadable = errors.New("unable to read options from bytes")
) )
// Capabilities indicates the capabilities and features provided by the server. // Capabilities indicates the capabilities and features provided by the server.
type Capabilities struct { type Capabilities struct {
MaximumMessageExpiryInterval int64 // maximum message expiry if message expiry is 0 or over MaximumMessageExpiryInterval int64 `yaml:"maximum_message_expiry_interval" json:"maximum_message_expiry_interval"` // maximum message expiry if message expiry is 0 or over
MaximumClientWritesPending int32 // maximum number of pending message writes for a client MaximumClientWritesPending int32 `yaml:"maximum_client_writes_pending" json:"maximum_client_writes_pending"` // maximum number of pending message writes for a client
MaximumSessionExpiryInterval uint32 // maximum number of seconds to keep disconnected sessions MaximumSessionExpiryInterval uint32 `yaml:"maximum_session_expiry_interval" json:"maximum_session_expiry_interval"` // maximum number of seconds to keep disconnected sessions
MaximumPacketSize uint32 // maximum packet size, no limit if 0 MaximumPacketSize uint32 `yaml:"maximum_packet_size" json:"maximum_packet_size"` // maximum packet size, no limit if 0
maximumPacketID uint32 // unexported, used for testing only maximumPacketID uint32 // unexported, used for testing only
ReceiveMaximum uint16 // maximum number of concurrent qos messages per client ReceiveMaximum uint16 `yaml:"receive_maximum" json:"receive_maximum"` // maximum number of concurrent qos messages per client
MaximumInflight uint16 // maximum number of qos > 0 messages can be stored, 0(=8192)-65535 MaximumInflight uint16 `yaml:"maximum_inflight" json:"maximum_inflight"` // maximum number of qos > 0 messages can be stored, 0(=8192)-65535
TopicAliasMaximum uint16 // maximum topic alias value TopicAliasMaximum uint16 `yaml:"topic_alias_maximum" json:"topic_alias_maximum"` // maximum topic alias value
SharedSubAvailable byte // support of shared subscriptions SharedSubAvailable byte `yaml:"shared_sub_available" json:"shared_sub_available"` // support of shared subscriptions
MinimumProtocolVersion byte // minimum supported mqtt version MinimumProtocolVersion byte `yaml:"minimum_protocol_version" json:"minimum_protocol_version"` // minimum supported mqtt version
Compatibilities Compatibilities Compatibilities Compatibilities `yaml:"compatibilities" json:"compatibilities"` // version compatibilities the server provides
MaximumQos byte // maximum qos value available to clients MaximumQos byte `yaml:"maximum_qos" json:"maximum_qos"` // maximum qos value available to clients
RetainAvailable byte // support of retain messages RetainAvailable byte `yaml:"retain_available" json:"retain_available"` // support of retain messages
WildcardSubAvailable byte // support of wildcard subscriptions WildcardSubAvailable byte `yaml:"wildcard_sub_available" json:"wildcard_sub_available"` // support of wildcard subscriptions
SubIDAvailable byte // support of subscription identifiers SubIDAvailable byte `yaml:"sub_id_available" json:"sub_id_available"` // support of subscription identifiers
} }
// NewDefaultServerCapabilities defines the default features and capabilities provided by the server. // NewDefaultServerCapabilities defines the default features and capabilities provided by the server.
@@ -82,43 +84,49 @@ func NewDefaultServerCapabilities() *Capabilities {
// Compatibilities provides flags for using compatibility modes. // Compatibilities provides flags for using compatibility modes.
type Compatibilities struct { type Compatibilities struct {
ObscureNotAuthorized bool // return unspecified errors instead of not authorized ObscureNotAuthorized bool `yaml:"obscure_not_authorized" json:"obscure_not_authorized"` // return unspecified errors instead of not authorized
PassiveClientDisconnect bool // don't disconnect the client forcefully after sending disconnect packet (paho - spec violation) PassiveClientDisconnect bool `yaml:"passive_client_disconnect" json:"passive_client_disconnect"` // don't disconnect the client forcefully after sending disconnect packet (paho - spec violation)
AlwaysReturnResponseInfo bool // always return response info (useful for testing) AlwaysReturnResponseInfo bool `yaml:"always_return_response_info" json:"always_return_response_info"` // always return response info (useful for testing)
RestoreSysInfoOnRestart bool // restore system info from store as if server never stopped RestoreSysInfoOnRestart bool `yaml:"restore_sys_info_on_restart" json:"restore_sys_info_on_restart"` // restore system info from store as if server never stopped
NoInheritedPropertiesOnAck bool // don't allow inherited user properties on ack (paho - spec violation) NoInheritedPropertiesOnAck bool `yaml:"no_inherited_properties_on_ack" json:"no_inherited_properties_on_ack"` // don't allow inherited user properties on ack (paho - spec violation)
} }
// Options contains configurable options for the server. // Options contains configurable options for the server.
type Options struct { type Options struct {
// Listeners specifies any listeners which should be dynamically added on serve. Used when setting listeners by config.
Listeners []listeners.Config `yaml:"listeners" json:"listeners"`
// Hooks specifies any hooks which should be dynamically added on serve. Used when setting hooks by config.
Hooks []HookLoadConfig `yaml:"hooks" json:"hooks"`
// Capabilities defines the server features and behaviour. If you only wish to modify // Capabilities defines the server features and behaviour. If you only wish to modify
// several of these values, set them explicitly - e.g. // several of these values, set them explicitly - e.g.
// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 // server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
Capabilities *Capabilities Capabilities *Capabilities `yaml:"capabilities" json:"capabilities"`
// ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer. // ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer.
ClientNetWriteBufferSize int ClientNetWriteBufferSize int `yaml:"client_net_write_buffer_size" json:"client_net_write_buffer_size"`
// ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer. // ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer.
ClientNetReadBufferSize int ClientNetReadBufferSize int `yaml:"client_net_read_buffer_size" json:"client_net_read_buffer_size"`
// Logger specifies a custom configured implementation of zerolog to override // Logger specifies a custom configured implementation of zerolog to override
// the servers default logger configuration. If you wish to change the log level, // the servers default logger configuration. If you wish to change the log level,
// of the default logger, you can do so by setting // of the default logger, you can do so by setting:
// server := mqtt.New(nil) // server := mqtt.New(nil)
// level := new(slog.LevelVar) // level := new(slog.LevelVar)
// server.Slog = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ // server.Slog = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
// Level: level, // Level: level,
// })) // }))
// level.Set(slog.LevelDebug) // level.Set(slog.LevelDebug)
Logger *slog.Logger Logger *slog.Logger `yaml:"-" json:"-"`
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds. // SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
SysTopicResendInterval int64 SysTopicResendInterval int64 `yaml:"sys_topic_resend_interval" json:"sys_topic_resend_interval"`
// Enable Inline client to allow direct subscribing and publishing from the parent codebase, // Enable Inline client to allow direct subscribing and publishing from the parent codebase,
// with negligible performance difference (disabled by default to prevent confusion in statistics). // with negligible performance difference (disabled by default to prevent confusion in statistics).
InlineClient bool InlineClient bool `yaml:"inline_client" json:"inline_client"`
} }
// Server is an MQTT broker server. It should be created with server.New() // Server is an MQTT broker server. It should be created with server.New()
@@ -262,6 +270,17 @@ func (s *Server) AddHook(hook Hook, config any) error {
return s.hooks.Add(hook, config) return s.hooks.Add(hook, config)
} }
// AddHooksFromConfig adds hooks to the server which were specified in the hooks config (usually from a config file).
// New built-in hooks should be added to this list.
func (s *Server) AddHooksFromConfig(hooks []HookLoadConfig) error {
for _, h := range hooks {
if err := s.AddHook(h.Hook, h.Config); err != nil {
return err
}
}
return nil
}
// AddListener adds a new network listener to the server, for receiving incoming client connections. // AddListener adds a new network listener to the server, for receiving incoming client connections.
func (s *Server) AddListener(l listeners.Listener) error { func (s *Server) AddListener(l listeners.Listener) error {
if _, ok := s.Listeners.Get(l.ID()); ok { if _, ok := s.Listeners.Get(l.ID()); ok {
@@ -280,12 +299,55 @@ func (s *Server) AddListener(l listeners.Listener) error {
return nil return nil
} }
// AddListenersFromConfig adds listeners to the server which were specified in the listeners config (usually from a config file).
// New built-in listeners should be added to this list.
func (s *Server) AddListenersFromConfig(configs []listeners.Config) error {
for _, conf := range configs {
var l listeners.Listener
switch strings.ToLower(conf.Type) {
case listeners.TypeTCP:
l = listeners.NewTCP(conf)
case listeners.TypeWS:
l = listeners.NewWebsocket(conf)
case listeners.TypeUnix:
l = listeners.NewUnixSock(conf)
case listeners.TypeHealthCheck:
l = listeners.NewHTTPHealthCheck(conf)
case listeners.TypeSysInfo:
l = listeners.NewHTTPStats(conf, s.Info)
case listeners.TypeMock:
l = listeners.NewMockListener(conf.ID, conf.Address)
default:
s.Log.Error("listener type unavailable by config", "listener", conf.Type)
continue
}
if err := s.AddListener(l); err != nil {
return err
}
}
return nil
}
// Serve starts the event loops responsible for establishing client connections // Serve starts the event loops responsible for establishing client connections
// on all attached listeners, publishing the system topics, and starting all hooks. // on all attached listeners, publishing the system topics, and starting all hooks.
func (s *Server) Serve() error { func (s *Server) Serve() error {
s.Log.Info("mochi mqtt starting", "version", Version) s.Log.Info("mochi mqtt starting", "version", Version)
defer s.Log.Info("mochi mqtt server started") defer s.Log.Info("mochi mqtt server started")
if len(s.Options.Listeners) > 0 {
err := s.AddListenersFromConfig(s.Options.Listeners)
if err != nil {
return err
}
}
if len(s.Options.Hooks) > 0 {
err := s.AddHooksFromConfig(s.Options.Hooks)
if err != nil {
return err
}
}
if s.hooks.Provides( if s.hooks.Provides(
StoredClients, StoredClients,
StoredInflightMessages, StoredInflightMessages,

View File

@@ -220,6 +220,34 @@ func TestServerAddListener(t *testing.T) {
require.Equal(t, ErrListenerIDExists, err) require.Equal(t, ErrListenerIDExists, err)
} }
func TestServerAddHooksFromConfig(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Log = logger
hooks := []HookLoadConfig{
{Hook: new(modifiedHookBase)},
}
err := s.AddHooksFromConfig(hooks)
require.NoError(t, err)
}
func TestServerAddHooksFromConfigError(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Log = logger
hooks := []HookLoadConfig{
{Hook: new(modifiedHookBase), Config: map[string]interface{}{}},
}
err := s.AddHooksFromConfig(hooks)
require.Error(t, err)
}
func TestServerAddListenerInitFailure(t *testing.T) { func TestServerAddListenerInitFailure(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
@@ -232,6 +260,60 @@ func TestServerAddListenerInitFailure(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestServerAddListenersFromConfig(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Log = logger
lc := []listeners.Config{
{Type: listeners.TypeTCP, ID: "tcp", Address: ":1883"},
{Type: listeners.TypeWS, ID: "ws", Address: ":1882"},
{Type: listeners.TypeHealthCheck, ID: "health", Address: ":1881"},
{Type: listeners.TypeSysInfo, ID: "info", Address: ":1880"},
{Type: listeners.TypeUnix, ID: "unix", Address: "mochi.sock"},
{Type: listeners.TypeMock, ID: "mock", Address: "0"},
{Type: "unknown", ID: "unknown"},
}
err := s.AddListenersFromConfig(lc)
require.NoError(t, err)
require.Equal(t, 6, s.Listeners.Len())
tcp, _ := s.Listeners.Get("tcp")
require.Equal(t, "[::]:1883", tcp.Address())
ws, _ := s.Listeners.Get("ws")
require.Equal(t, ":1882", ws.Address())
health, _ := s.Listeners.Get("health")
require.Equal(t, ":1881", health.Address())
info, _ := s.Listeners.Get("info")
require.Equal(t, ":1880", info.Address())
unix, _ := s.Listeners.Get("unix")
require.Equal(t, "mochi.sock", unix.Address())
mock, _ := s.Listeners.Get("mock")
require.Equal(t, "0", mock.Address())
}
func TestServerAddListenersFromConfigError(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Log = logger
lc := []listeners.Config{
{Type: listeners.TypeTCP, ID: "tcp", Address: "x"},
}
err := s.AddListenersFromConfig(lc)
require.Error(t, err)
require.Equal(t, 0, s.Listeners.Len())
}
func TestServerServe(t *testing.T) { func TestServerServe(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
@@ -253,6 +335,57 @@ func TestServerServe(t *testing.T) {
require.Equal(t, true, listener.(*listeners.MockListener).IsServing()) require.Equal(t, true, listener.(*listeners.MockListener).IsServing())
} }
func TestServerServeFromConfig(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Options.Listeners = []listeners.Config{
{Type: listeners.TypeMock, ID: "mock", Address: "0"},
}
s.Options.Hooks = []HookLoadConfig{
{Hook: new(modifiedHookBase)},
}
err := s.Serve()
require.NoError(t, err)
time.Sleep(time.Millisecond)
require.Equal(t, 1, s.Listeners.Len())
listener, ok := s.Listeners.Get("mock")
require.Equal(t, true, ok)
require.Equal(t, true, listener.(*listeners.MockListener).IsServing())
}
func TestServerServeFromConfigListenerError(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Options.Listeners = []listeners.Config{
{Type: listeners.TypeTCP, ID: "tcp", Address: "x"},
}
err := s.Serve()
require.Error(t, err)
}
func TestServerServeFromConfigHookError(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Options.Hooks = []HookLoadConfig{
{Hook: new(modifiedHookBase), Config: map[string]interface{}{}},
}
err := s.Serve()
require.Error(t, err)
}
func TestServerServeReadStoreFailure(t *testing.T) { func TestServerServeReadStoreFailure(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()