Implement file-based configuration

This commit is contained in:
mochi-co
2023-12-25 00:28:49 +00:00
parent 5058333f36
commit ac5863644a
42 changed files with 1072 additions and 193 deletions

View File

@@ -11,21 +11,11 @@ RUN go mod download
COPY . ./ COPY . ./
RUN go build -o /app/mochi ./cmd RUN go build -o /app/mochi ./cmd/docker
FROM alpine FROM alpine
WORKDIR / WORKDIR /
COPY --from=builder /app/mochi . COPY --from=builder /app/mochi .
# tcp
EXPOSE 1883
# websockets
EXPOSE 1882
# dashboard
EXPOSE 8080
ENTRYPOINT [ "/mochi" ] ENTRYPOINT [ "/mochi" ]

56
cmd/docker/main.go Normal file
View File

@@ -0,0 +1,56 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt
// SPDX-FileContributor: dgduncan, mochi-co
package main
import (
"flag"
"github.com/mochi-mqtt/server/v2/config"
"log"
"log/slog"
"os"
"os/signal"
"syscall"
mqtt "github.com/mochi-mqtt/server/v2"
)
func main() {
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, nil))) // set basic logger to ensure logs before configuration are in a consistent format
configFile := flag.String("config", "config.yaml", "path to mochi config yaml or json file")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
configBytes, err := os.ReadFile(*configFile)
if err != nil {
log.Fatal(err)
}
options, err := config.FromBytes(configBytes)
if err != nil {
log.Fatal(err)
}
server := mqtt.New(options)
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("mochi mqtt shutdown complete")
}

View File

@@ -33,19 +33,31 @@ func main() {
server := mqtt.New(nil) server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: *tcpAddr,
})
err := server.AddListener(tcp) err := server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
ws := listeners.NewWebsocket("ws1", *wsAddr, nil) ws := listeners.NewWebsocket(listeners.Config{
ID: "ws1",
Address: *wsAddr,
})
err = server.AddListener(ws) err = server.AddListener(ws)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
stats := listeners.NewHTTPStats("stats", *infoAddr, nil, server.Info) stats := listeners.NewHTTPStats(
listeners.Config{
ID: "info",
Address: *infoAddr,
},
server.Info,
)
err = server.AddListener(stats) err = server.AddListener(stats)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@@ -61,6 +73,5 @@ func main() {
<-done <-done
server.Log.Warn("caught signal, stopping...") server.Log.Warn("caught signal, stopping...")
_ = server.Close() _ = server.Close()
server.Log.Info("main.go finished") server.Log.Info("mochi mqtt shutdown complete")
} }

15
config.yaml Normal file
View File

@@ -0,0 +1,15 @@
listeners:
- type: "tcp"
id: "tcp1"
address: ":1883"
- type: "ws"
id: "ws1"
address: ":1882"
- type: "sysinfo"
id: "stats"
address: ":1880"
hooks:
auth:
allow_all: true
options:
inline_client: true

135
config/config.go Normal file
View File

@@ -0,0 +1,135 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package config
import (
"encoding/json"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/debug"
"github.com/mochi-mqtt/server/v2/hooks/storage/badger"
"github.com/mochi-mqtt/server/v2/hooks/storage/bolt"
"github.com/mochi-mqtt/server/v2/hooks/storage/redis"
"github.com/mochi-mqtt/server/v2/listeners"
"gopkg.in/yaml.v3"
mqtt "github.com/mochi-mqtt/server/v2"
)
type config struct {
Options mqtt.Options
Listeners []listeners.Config `yaml:"listeners" json:"listeners"`
HookConfigs HookConfigs `yaml:"hooks" json:"hooks"`
}
type HookConfigs struct {
Auth *HookAuthConfig `yaml:"auth" json:"auth"`
Storage *HookStorageConfig `yaml:"storage" json:"storage"`
Debug *debug.Options `yaml:"debug" json:"debug"`
}
type HookAuthConfig struct {
Ledger auth.Ledger `yaml:"ledger" json:"ledger"`
AllowAll bool `yaml:"allow_all" json:"allow_all"`
}
type HookStorageConfig struct {
Badger *badger.Options `yaml:"badger" json:"badger"`
Bolt *bolt.Options `yaml:"bolt" json:"bolt"`
Redis *redis.Options `yaml:"redis" json:"redis"`
}
func (hc HookConfigs) ToHooks() []mqtt.HookLoadConfig {
var hlc []mqtt.HookLoadConfig
if hc.Auth != nil {
hlc = append(hlc, hc.toHooksAuth()...)
}
if hc.Storage != nil {
hlc = append(hlc, hc.toHooksAuth()...)
}
if hc.Debug != nil {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(debug.Hook),
Config: hc.Debug,
})
}
return hlc
}
func (hc HookConfigs) toHooksAuth() []mqtt.HookLoadConfig {
var hlc []mqtt.HookLoadConfig
if hc.Auth.AllowAll {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(auth.AllowHook),
})
} else {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(auth.Hook),
Config: &auth.Options{
Ledger: &auth.Ledger{ // avoid copying sync.Locker
Users: hc.Auth.Ledger.Users,
Auth: hc.Auth.Ledger.Auth,
ACL: hc.Auth.Ledger.ACL,
},
},
})
}
return hlc
}
func (hc HookConfigs) toHooksStorage() []mqtt.HookLoadConfig {
var hlc []mqtt.HookLoadConfig
if hc.Storage.Badger != nil {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(badger.Hook),
Config: hc.Storage.Badger,
})
}
if hc.Storage.Bolt != nil {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(bolt.Hook),
Config: hc.Storage.Bolt,
})
}
if hc.Storage.Redis != nil {
hlc = append(hlc, mqtt.HookLoadConfig{
Hook: new(redis.Hook),
Config: hc.Storage.Redis,
})
}
return hlc
}
func FromBytes(b []byte) (*mqtt.Options, error) {
c := new(config)
o := mqtt.Options{}
if len(b) == 0 {
return nil, nil
}
if b[0] == '{' {
err := json.Unmarshal(b, c)
if err != nil {
return nil, err
}
} else {
err := yaml.Unmarshal(b, c)
if err != nil {
return nil, err
}
}
o = c.Options
o.Hooks = c.HookConfigs.ToHooks()
o.Listeners = c.Listeners
return &o, nil
}

212
config/config_test.go Normal file
View File

@@ -0,0 +1,212 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package config
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/mochi-mqtt/server/v2/hooks/auth"
"github.com/mochi-mqtt/server/v2/hooks/storage/badger"
"github.com/mochi-mqtt/server/v2/hooks/storage/bolt"
"github.com/mochi-mqtt/server/v2/hooks/storage/redis"
"github.com/mochi-mqtt/server/v2/listeners"
mqtt "github.com/mochi-mqtt/server/v2"
)
var (
yamlBytes = []byte(`
listeners:
- type: "tcp"
id: "file-tcp1"
address: ":1883"
hooks:
auth:
allow_all: true
options:
client_net_write_buffer_size: 2048
capabilities:
minimum_protocol_version: 3
compatibilities:
restore_sys_info_on_restart: true
`)
jsonBytes = []byte(`{
"listeners": [
{
"type": "tcp",
"id": "file-tcp1",
"address": ":1883"
}
],
"hooks": {
"auth": {
"allow_all": true
}
},
"options": {
"client_net_write_buffer_size": 2048,
"capabilities": {
"minimum_protocol_version": 3,
"compatibilities": {
"restore_sys_info_on_restart": true
}
}
}
}
`)
parsedOptions = mqtt.Options{
Listeners: []listeners.Config{
{
Type: listeners.TypeTCP,
ID: "file-tcp1",
Address: ":1883",
},
},
Hooks: []mqtt.HookLoadConfig{
{
Hook: new(auth.AllowHook),
},
},
ClientNetWriteBufferSize: 2048,
Capabilities: &mqtt.Capabilities{
MinimumProtocolVersion: 3,
Compatibilities: mqtt.Compatibilities{
RestoreSysInfoOnRestart: true,
},
},
}
)
func TestFromBytesEmptyL(t *testing.T) {
_, err := FromBytes([]byte{})
require.NoError(t, err)
}
func TestFromBytesYAML(t *testing.T) {
o, err := FromBytes(yamlBytes)
require.NoError(t, err)
require.Equal(t, parsedOptions, *o)
}
func TestFromBytesYAMLError(t *testing.T) {
_, err := FromBytes(append(yamlBytes, 'a'))
require.Error(t, err)
}
func TestFromBytesJSON(t *testing.T) {
o, err := FromBytes(jsonBytes)
require.NoError(t, err)
require.Equal(t, parsedOptions, *o)
}
func TestFromBytesJSONError(t *testing.T) {
_, err := FromBytes(append(jsonBytes, 'a'))
require.Error(t, err)
}
func TestToHooksAuthAllowAll(t *testing.T) {
hc := HookConfigs{
Auth: &HookAuthConfig{
AllowAll: true,
},
}
th := hc.toHooksAuth()
expect := []mqtt.HookLoadConfig{
{Hook: new(auth.AllowHook)},
}
require.Equal(t, expect, th)
}
func TestToHooksAuthAllowLedger(t *testing.T) {
hc := HookConfigs{
Auth: &HookAuthConfig{
Ledger: auth.Ledger{
Auth: auth.AuthRules{
{Username: "peach", Password: "password1", Allow: true},
},
},
},
}
th := hc.toHooksAuth()
expect := []mqtt.HookLoadConfig{
{
Hook: new(auth.Hook),
Config: &auth.Options{
Ledger: &auth.Ledger{ // avoid copying sync.Locker
Auth: auth.AuthRules{
{Username: "peach", Password: "password1", Allow: true},
},
},
},
},
}
require.Equal(t, expect, th)
}
func TestToHooksStorageBadger(t *testing.T) {
hc := HookConfigs{
Storage: &HookStorageConfig{
Badger: &badger.Options{
Path: "badger",
},
},
}
th := hc.toHooksStorage()
expect := []mqtt.HookLoadConfig{
{
Hook: new(badger.Hook),
Config: hc.Storage.Badger,
},
}
require.Equal(t, expect, th)
}
func TestToHooksStorageBolt(t *testing.T) {
hc := HookConfigs{
Storage: &HookStorageConfig{
Bolt: &bolt.Options{
Path: "bolt",
},
},
}
th := hc.toHooksStorage()
expect := []mqtt.HookLoadConfig{
{
Hook: new(bolt.Hook),
Config: hc.Storage.Bolt,
},
}
require.Equal(t, expect, th)
}
func TestToHooksStorageRedis(t *testing.T) {
hc := HookConfigs{
Storage: &HookStorageConfig{
Redis: &redis.Options{
Username: "test",
},
},
}
th := hc.toHooksStorage()
expect := []mqtt.HookLoadConfig{
{
Hook: new(redis.Hook),
Config: hc.Storage.Redis,
},
}
require.Equal(t, expect, th)
}

View File

@@ -63,7 +63,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -45,7 +45,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -32,7 +32,10 @@ func main() {
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: *tcpAddr,
})
err := server.AddListener(tcp) err := server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -0,0 +1,88 @@
{
"listeners": [
{
"type": "tcp",
"id": "file-tcp1",
"address": ":1883"
},
{
"type": "ws",
"id": "file-websocket",
"address": ":1882"
},
{
"type": "healthcheck",
"id": "file-healthcheck",
"address": ":1880"
}
],
"hooks": {
"debug": {
"enable": true
},
"storage": {
"badger": {
"path": "badger.db"
},
"bolt": {
"path": "bolt.db"
},
"redis": {
"h_prefix": "mc",
"username": "mochi",
"password": "melon",
"address": "localhost:6379",
"database": 1
}
},
"auth": {
"ledger": {
"auth": [
{
"username": "peach",
"password": "password1",
"allow": true
}
],
"acl": [
{
"remote": "127.0.0.1:*"
},
{
"username": "melon",
"filters": null,
"melon/#": 3,
"updates/#": 2
}
]
}
}
},
"options": {
"client_net_write_buffer_size": 2048,
"client_net_read_buffer_size": 2048,
"sys_topic_resend_interval": 10,
"inline_client": true,
"capabilities": {
"maximum_message_expiry_interval": 100,
"maximum_client_writes_pending": 8192,
"maximum_session_expiry_interval": 86400,
"maximum_packet_size": 0,
"receive_maximum": 1024,
"topic_alias_maximum": 65535,
"shared_sub_available": 1,
"minimum_protocol_version": 3,
"maximum_qos": 2,
"retain_available": 1,
"wildcard_sub_available": 1,
"sub_id_available": 1,
"compatibilities": {
"obscure_not_authorized": true,
"passive_client_disconnect": false,
"always_return_response_info": false,
"restore_sys_info_on_restart": false,
"no_inherited_properties_on_ack": false
}
}
}
}

View File

@@ -0,0 +1,60 @@
listeners:
- type: "tcp"
id: "file-tcp1"
address: ":1883"
- type: "ws"
id: "file-websocket"
address: ":1882"
- type: "healthcheck"
id: "file-healthcheck"
address: ":1880"
hooks:
debug:
enable: true
storage:
badger:
path: badger.db
bolt:
path: bolt.db
redis:
h_prefix: "mc"
username: "mochi"
password: "melon"
address: "localhost:6379"
database: 1
auth:
ledger:
auth:
- username: peach
password: password1
allow: true
acl:
- remote: 127.0.0.1:*
- username: melon
filters:
melon/#: 3
updates/#: 2
options:
client_net_write_buffer_size: 2048
client_net_read_buffer_size: 2048
sys_topic_resend_interval: 10
inline_client: true
capabilities:
maximum_message_expiry_interval: 100
maximum_client_writes_pending: 8192
maximum_session_expiry_interval: 86400
maximum_packet_size: 0
receive_maximum: 1024
topic_alias_maximum: 65535
shared_sub_available: 1
minimum_protocol_version: 3
maximum_qos: 2
retain_available: 1
wildcard_sub_available: 1
sub_id_available: 1
compatibilities:
obscure_not_authorized: true
passive_client_disconnect: false
always_return_response_info: false
restore_sys_info_on_restart: false
no_inherited_properties_on_ack: false

49
examples/config/main.go Normal file
View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"github.com/mochi-mqtt/server/v2/config"
"log"
"os"
"os/signal"
"syscall"
mqtt "github.com/mochi-mqtt/server/v2"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
configBytes, err := os.ReadFile("config.yaml")
if err != nil {
log.Fatal(err)
}
options, err := config.FromBytes(configBytes)
if err != nil {
log.Fatal(err)
}
server := mqtt.New(options)
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn("caught signal, stopping...")
_ = server.Close()
server.Log.Info("main.go finished")
}

View File

@@ -46,7 +46,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -33,7 +33,10 @@ func main() {
}) })
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err := server.AddListener(tcp) err := server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -31,7 +31,10 @@ func main() {
server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
_ = server.AddHook(new(pahoAuthHook), nil) _ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err := server.AddListener(tcp) err := server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -38,7 +38,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -40,7 +40,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -48,7 +48,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -38,7 +38,10 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
tcp := listeners.NewTCP("t1", ":1883", nil) tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
})
err = server.AddListener(tcp) err = server.AddListener(tcp)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -79,7 +79,9 @@ func main() {
server := mqtt.New(nil) server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", &listeners.Config{ tcp := listeners.NewTCP(listeners.Config{
ID: "t1",
Address: ":1883",
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}) })
err = server.AddListener(tcp) err = server.AddListener(tcp)
@@ -87,7 +89,9 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
ws := listeners.NewWebsocket("ws1", ":1882", &listeners.Config{ ws := listeners.NewWebsocket(listeners.Config{
ID: "ws1",
Address: ":1882",
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}) })
err = server.AddListener(ws) err = server.AddListener(ws)
@@ -95,9 +99,13 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{ stats := listeners.NewHTTPStats(
listeners.Config{
ID: "stats",
Address: ":8080",
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}, server.Info) }, server.Info,
)
err = server.AddListener(stats) err = server.AddListener(stats)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -27,7 +27,10 @@ func main() {
server := mqtt.New(nil) server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil) _ = server.AddHook(new(auth.AllowHook), nil)
ws := listeners.NewWebsocket("ws1", ":1882", nil) ws := listeners.NewWebsocket(listeners.Config{
ID: "ws1",
Address: ":1882",
})
err := server.AddListener(ws) err := server.AddListener(ws)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@@ -62,6 +62,11 @@ var (
ErrInvalidConfigType = errors.New("invalid config type provided") ErrInvalidConfigType = errors.New("invalid config type provided")
) )
type HookLoadConfig struct {
Hook Hook
Config any
}
// Hook provides an interface of handlers for different events which occur // Hook provides an interface of handlers for different events which occur
// during the lifecycle of the broker. // during the lifecycle of the broker.
type Hook interface { type Hook interface {
@@ -70,6 +75,7 @@ type Hook interface {
Init(config any) error Init(config any) error
Stop() error Stop() error
SetOpts(l *slog.Logger, o *HookOptions) SetOpts(l *slog.Logger, o *HookOptions)
OnStarted() OnStarted()
OnStopped() OnStopped()
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool OnConnectAuthenticate(cl *Client, pk packets.Packet) bool

View File

@@ -16,9 +16,10 @@ import (
// Options contains configuration settings for the debug output. // Options contains configuration settings for the debug output.
type Options struct { type Options struct {
ShowPacketData bool // include decoded packet data (default false) Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config
ShowPings bool // show ping requests and responses (default false) ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false)
ShowPasswords bool // show connecting user passwords (default false) ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false)
ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false)
} }
// Hook is a debugging hook which logs additional low-level information from the server. // Hook is a debugging hook which logs additional low-level information from the server.

View File

@@ -51,7 +51,7 @@ func sysInfoKey() string {
// Options contains configuration settings for the BadgerDB instance. // Options contains configuration settings for the BadgerDB instance.
type Options struct { type Options struct {
Options *badgerhold.Options Options *badgerhold.Options
Path string Path string `yaml:"path" json:"path"`
} }
// Hook is a persistent storage hook based using BadgerDB file store as a backend. // Hook is a persistent storage hook based using BadgerDB file store as a backend.

View File

@@ -56,7 +56,7 @@ func sysInfoKey() string {
// Options contains configuration settings for the bolt instance. // Options contains configuration settings for the bolt instance.
type Options struct { type Options struct {
Options *bbolt.Options Options *bbolt.Options
Path string Path string `yaml:"path" json:"path"`
} }
// Hook is a persistent storage hook based using boltdb file store as a backend. // Hook is a persistent storage hook based using boltdb file store as a backend.

View File

@@ -51,7 +51,11 @@ func sysInfoKey() string {
// Options contains configuration settings for the bolt instance. // Options contains configuration settings for the bolt instance.
type Options struct { type Options struct {
HPrefix string Address string `yaml:"address" json:"address"`
Username string `yaml:"username" json:"username"`
Password string `yaml:"password" json:"password"`
Database int `yaml:"database" json:"database"`
HPrefix string `yaml:"h_prefix" json:"h_prefix"`
Options *redis.Options Options *redis.Options
} }
@@ -105,23 +109,31 @@ func (h *Hook) Init(config any) error {
h.ctx = context.Background() h.ctx = context.Background()
if config == nil { if config == nil {
config = &Options{ config = new(Options)
Options: &redis.Options{
Addr: defaultAddr,
},
} }
h.config = config.(*Options)
if h.config.Options == nil {
h.config.Options = &redis.Options{
Addr: defaultAddr,
}
h.config.Options.Addr = h.config.Address
h.config.Options.DB = h.config.Database
h.config.Options.Username = h.config.Username
h.config.Options.Password = h.config.Password
} }
h.config = config.(*Options)
if h.config.HPrefix == "" { if h.config.HPrefix == "" {
h.config.HPrefix = defaultHPrefix h.config.HPrefix = defaultHPrefix
} }
h.Log.Info("connecting to redis service", h.Log.Info(
"connecting to redis service",
"prefix", h.config.HPrefix,
"address", h.config.Options.Addr, "address", h.config.Options.Addr,
"username", h.config.Options.Username, "username", h.config.Options.Username,
"password-len", len(h.config.Options.Password), "password-len", len(h.config.Options.Password),
"db", h.config.Options.DB) "db", h.config.Options.DB,
)
h.db = redis.NewClient(h.config.Options) h.db = redis.NewClient(h.config.Options)
_, err := h.db.Ping(context.Background()).Result() _, err := h.db.Ping(context.Background()).Result()

View File

@@ -135,6 +135,29 @@ func TestInitUseDefaults(t *testing.T) {
require.Equal(t, defaultAddr, h.config.Options.Addr) require.Equal(t, defaultAddr, h.config.Options.Addr)
} }
func TestInitUsePassConfig(t *testing.T) {
s := miniredis.RunT(t)
s.StartAddr(defaultAddr)
defer s.Close()
h := newHook(t, "")
h.SetOpts(logger, nil)
err := h.Init(&Options{
Address: defaultAddr,
Username: "username",
Password: "password",
Database: 2,
})
require.Error(t, err)
h.db.FlushAll(h.ctx)
require.Equal(t, defaultAddr, h.config.Options.Addr)
require.Equal(t, "username", h.config.Options.Username)
require.Equal(t, "password", h.config.Options.Password)
require.Equal(t, 2, h.config.Options.DB)
}
func TestInitBadConfig(t *testing.T) { func TestInitBadConfig(t *testing.T) {
h := new(Hook) h := new(Hook)
h.SetOpts(logger, nil) h.SetOpts(logger, nil)

View File

@@ -13,24 +13,23 @@ import (
"time" "time"
) )
const TypeHealthCheck = "healthcheck"
// HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint. // HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint.
type HTTPHealthCheck struct { type HTTPHealthCheck struct {
sync.RWMutex sync.RWMutex
id string // the internal id of the listener id string // the internal id of the listener
address string // the network address to bind to address string // the network address to bind to
config *Config // configuration values for the listener config Config // configuration values for the listener
listen *http.Server // the http server listen *http.Server // the http server
end uint32 // ensure the close methods are only called once end uint32 // ensure the close methods are only called once
} }
// NewHTTPHealthCheck initialises and returns a new HTTP listener, listening on an address. // NewHTTPHealthCheck initializes and returns a new HTTP listener, listening on an address.
func NewHTTPHealthCheck(id, address string, config *Config) *HTTPHealthCheck { func NewHTTPHealthCheck(config Config) *HTTPHealthCheck {
if config == nil {
config = new(Config)
}
return &HTTPHealthCheck{ return &HTTPHealthCheck{
id: id, id: config.ID,
address: address, address: config.Address,
config: config, config: config,
} }
} }

View File

@@ -14,47 +14,44 @@ import (
) )
func TestNewHTTPHealthCheck(t *testing.T) { func TestNewHTTPHealthCheck(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
require.Equal(t, "healthcheck", l.id) require.Equal(t, basicConfig.ID, l.id)
require.Equal(t, testAddr, l.address) require.Equal(t, basicConfig.Address, l.address)
} }
func TestHTTPHealthCheckID(t *testing.T) { func TestHTTPHealthCheckID(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
require.Equal(t, "healthcheck", l.ID()) require.Equal(t, basicConfig.ID, l.ID())
} }
func TestHTTPHealthCheckAddress(t *testing.T) { func TestHTTPHealthCheckAddress(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
require.Equal(t, testAddr, l.Address()) require.Equal(t, basicConfig.Address, l.Address())
} }
func TestHTTPHealthCheckProtocol(t *testing.T) { func TestHTTPHealthCheckProtocol(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
require.Equal(t, "http", l.Protocol()) require.Equal(t, "http", l.Protocol())
} }
func TestHTTPHealthCheckTLSProtocol(t *testing.T) { func TestHTTPHealthCheckTLSProtocol(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ l := NewHTTPHealthCheck(tlsConfig)
TLSConfig: tlsConfigBasic,
})
_ = l.Init(logger) _ = l.Init(logger)
require.Equal(t, "https", l.Protocol()) require.Equal(t, "https", l.Protocol())
} }
func TestHTTPHealthCheckInit(t *testing.T) { func TestHTTPHealthCheckInit(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, l.listen) require.NotNil(t, l.listen)
require.Equal(t, testAddr, l.listen.Addr) require.Equal(t, basicConfig.Address, l.listen.Addr)
} }
func TestHTTPHealthCheckServeAndClose(t *testing.T) { func TestHTTPHealthCheckServeAndClose(t *testing.T) {
// setup http stats listener // setup http stats listener
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -90,7 +87,7 @@ func TestHTTPHealthCheckServeAndClose(t *testing.T) {
func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) { func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
// setup http stats listener // setup http stats listener
l := NewHTTPHealthCheck("healthcheck", testAddr, nil) l := NewHTTPHealthCheck(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -125,10 +122,7 @@ func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) {
} }
func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) { func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) {
l := NewHTTPHealthCheck("healthcheck", testAddr, &Config{ l := NewHTTPHealthCheck(tlsConfig)
TLSConfig: tlsConfigBasic,
})
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -17,27 +17,26 @@ import (
"github.com/mochi-mqtt/server/v2/system" "github.com/mochi-mqtt/server/v2/system"
) )
const TypeSysInfo = "sysinfo"
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. // HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct { type HTTPStats struct {
sync.RWMutex sync.RWMutex
id string // the internal id of the listener id string // the internal id of the listener
address string // the network address to bind to address string // the network address to bind to
config *Config // configuration values for the listener config Config // configuration values for the listener
listen *http.Server // the http server listen *http.Server // the http server
sysInfo *system.Info // pointers to the server data sysInfo *system.Info // pointers to the server data
log *slog.Logger // server logger log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once end uint32 // ensure the close methods are only called once
} }
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address. // NewHTTPStats initializes and returns a new HTTP listener, listening on an address.
func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats { func NewHTTPStats(config Config, sysInfo *system.Info) *HTTPStats {
if config == nil {
config = new(Config)
}
return &HTTPStats{ return &HTTPStats{
id: id,
address: address,
sysInfo: sysInfo, sysInfo: sysInfo,
id: config.ID,
address: config.Address,
config: config, config: config,
} }
} }

View File

@@ -17,38 +17,35 @@ import (
) )
func TestNewHTTPStats(t *testing.T) { func TestNewHTTPStats(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil) l := NewHTTPStats(basicConfig, nil)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address) require.Equal(t, testAddr, l.address)
} }
func TestHTTPStatsID(t *testing.T) { func TestHTTPStatsID(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil) l := NewHTTPStats(basicConfig, nil)
require.Equal(t, "t1", l.ID()) require.Equal(t, "t1", l.ID())
} }
func TestHTTPStatsAddress(t *testing.T) { func TestHTTPStatsAddress(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil) l := NewHTTPStats(basicConfig, nil)
require.Equal(t, testAddr, l.Address()) require.Equal(t, testAddr, l.Address())
} }
func TestHTTPStatsProtocol(t *testing.T) { func TestHTTPStatsProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil) l := NewHTTPStats(basicConfig, nil)
require.Equal(t, "http", l.Protocol()) require.Equal(t, "http", l.Protocol())
} }
func TestHTTPStatsTLSProtocol(t *testing.T) { func TestHTTPStatsTLSProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, &Config{ l := NewHTTPStats(tlsConfig, nil)
TLSConfig: tlsConfigBasic,
}, nil)
_ = l.Init(logger) _ = l.Init(logger)
require.Equal(t, "https", l.Protocol()) require.Equal(t, "https", l.Protocol())
} }
func TestHTTPStatsInit(t *testing.T) { func TestHTTPStatsInit(t *testing.T) {
sysInfo := new(system.Info) sysInfo := new(system.Info)
l := NewHTTPStats("t1", testAddr, nil, sysInfo) l := NewHTTPStats(basicConfig, sysInfo)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -64,7 +61,7 @@ func TestHTTPStatsServeAndClose(t *testing.T) {
} }
// setup http stats listener // setup http stats listener
l := NewHTTPStats("t1", testAddr, nil, sysInfo) l := NewHTTPStats(basicConfig, sysInfo)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -109,9 +106,7 @@ func TestHTTPStatsServeTLSAndClose(t *testing.T) {
Version: "test", Version: "test",
} }
l := NewHTTPStats("t1", testAddr, &Config{ l := NewHTTPStats(tlsConfig, sysInfo)
TLSConfig: tlsConfigBasic,
}, sysInfo)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -132,7 +127,9 @@ func TestHTTPStatsFailedToServe(t *testing.T) {
} }
// setup http stats listener // setup http stats listener
l := NewHTTPStats("t1", "wrong_addr", nil, sysInfo) config := basicConfig
config.Address = "wrong_addr"
l := NewHTTPStats(config, sysInfo)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -14,8 +14,10 @@ import (
// Config contains configuration values for a listener. // Config contains configuration values for a listener.
type Config struct { type Config struct {
// TLSConfig is a tls.Config configuration to be used with the listener. Type string
// See examples folder for basic and mutual-tls use. ID string
Address string
// TLSConfig is a tls.Config configuration to be used with the listener. See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config TLSConfig *tls.Config
} }

View File

@@ -19,6 +19,9 @@ import (
const testAddr = ":22222" const testAddr = ":22222"
var ( var (
basicConfig = Config{ID: "t1", Address: testAddr}
tlsConfig = Config{ID: "t1", Address: testAddr, TLSConfig: tlsConfigBasic}
logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
testCertificate = []byte(`-----BEGIN CERTIFICATE----- testCertificate = []byte(`-----BEGIN CERTIFICATE-----
@@ -65,6 +68,7 @@ func init() {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
} }
tlsConfig.TLSConfig = tlsConfigBasic
} }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {

View File

@@ -12,6 +12,8 @@ import (
"log/slog" "log/slog"
) )
const TypeMock = "mock"
// MockEstablisher is a function signature which can be used in testing. // MockEstablisher is a function signature which can be used in testing.
func MockEstablisher(id string, c net.Conn) error { func MockEstablisher(id string, c net.Conn) error {
return nil return nil

View File

@@ -13,26 +13,24 @@ import (
"log/slog" "log/slog"
) )
const TypeTCP = "tcp"
// TCP is a listener for establishing client connections on basic TCP protocol. // TCP is a listener for establishing client connections on basic TCP protocol.
type TCP struct { // [MQTT-4.2.0-1] type TCP struct { // [MQTT-4.2.0-1]
sync.RWMutex sync.RWMutex
id string // the internal id of the listener id string // the internal id of the listener
address string // the network address to bind to address string // the network address to bind to
listen net.Listener // a net.Listener which will listen for new clients listen net.Listener // a net.Listener which will listen for new clients
config *Config // configuration values for the listener config Config // configuration values for the listener
log *slog.Logger // server logger log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once end uint32 // ensure the close methods are only called once
} }
// NewTCP initialises and returns a new TCP listener, listening on an address. // NewTCP initializes and returns a new TCP listener, listening on an address.
func NewTCP(id, address string, config *Config) *TCP { func NewTCP(config Config) *TCP {
if config == nil {
config = new(Config)
}
return &TCP{ return &TCP{
id: id, id: config.ID,
address: address, address: config.Address,
config: config, config: config,
} }
} }

View File

@@ -14,45 +14,40 @@ import (
) )
func TestNewTCP(t *testing.T) { func TestNewTCP(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address) require.Equal(t, testAddr, l.address)
} }
func TestTCPID(t *testing.T) { func TestTCPID(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
require.Equal(t, "t1", l.ID()) require.Equal(t, "t1", l.ID())
} }
func TestTCPAddress(t *testing.T) { func TestTCPAddress(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
require.Equal(t, testAddr, l.Address()) require.Equal(t, testAddr, l.Address())
} }
func TestTCPProtocol(t *testing.T) { func TestTCPProtocol(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
require.Equal(t, "tcp", l.Protocol()) require.Equal(t, "tcp", l.Protocol())
} }
func TestTCPProtocolTLS(t *testing.T) { func TestTCPProtocolTLS(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{ l := NewTCP(tlsConfig)
TLSConfig: tlsConfigBasic,
})
_ = l.Init(logger) _ = l.Init(logger)
defer l.listen.Close() defer l.listen.Close()
require.Equal(t, "tcp", l.Protocol()) require.Equal(t, "tcp", l.Protocol())
} }
func TestTCPInit(t *testing.T) { func TestTCPInit(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
err := l.Init(logger) err := l.Init(logger)
l.Close(MockCloser) l.Close(MockCloser)
require.NoError(t, err) require.NoError(t, err)
l2 := NewTCP("t2", testAddr, &Config{ l2 := NewTCP(tlsConfig)
TLSConfig: tlsConfigBasic,
})
err = l2.Init(logger) err = l2.Init(logger)
l2.Close(MockCloser) l2.Close(MockCloser)
require.NoError(t, err) require.NoError(t, err)
@@ -60,7 +55,7 @@ func TestTCPInit(t *testing.T) {
} }
func TestTCPServeAndClose(t *testing.T) { func TestTCPServeAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -85,9 +80,7 @@ func TestTCPServeAndClose(t *testing.T) {
} }
func TestTCPServeTLSAndClose(t *testing.T) { func TestTCPServeTLSAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{ l := NewTCP(tlsConfig)
TLSConfig: tlsConfigBasic,
})
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -109,7 +102,7 @@ func TestTCPServeTLSAndClose(t *testing.T) {
} }
func TestTCPEstablishThenEnd(t *testing.T) { func TestTCPEstablishThenEnd(t *testing.T) {
l := NewTCP("t1", testAddr, nil) l := NewTCP(basicConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -13,21 +13,25 @@ import (
"log/slog" "log/slog"
) )
const TypeUnix = "unix"
// UnixSock is a listener for establishing client connections on basic UnixSock protocol. // UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct { type UnixSock struct {
sync.RWMutex sync.RWMutex
id string // the internal id of the listener. id string // the internal id of the listener.
address string // the network address to bind to. address string // the network address to bind to.
config Config // configuration values for the listener
listen net.Listener // a net.Listener which will listen for new clients. listen net.Listener // a net.Listener which will listen for new clients.
log *slog.Logger // server logger log *slog.Logger // server logger
end uint32 // ensure the close methods are only called once. end uint32 // ensure the close methods are only called once.
} }
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address. // NewUnixSock initializes and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock { func NewUnixSock(config Config) *UnixSock {
return &UnixSock{ return &UnixSock{
id: id, id: config.ID,
address: address, address: config.Address,
config: config,
} }
} }

View File

@@ -15,41 +15,47 @@ import (
const testUnixAddr = "mochi.sock" const testUnixAddr = "mochi.sock"
var (
unixConfig = Config{ID: "t1", Address: testUnixAddr}
)
func TestNewUnixSock(t *testing.T) { func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address) require.Equal(t, testUnixAddr, l.address)
} }
func TestUnixSockID(t *testing.T) { func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
require.Equal(t, "t1", l.ID()) require.Equal(t, "t1", l.ID())
} }
func TestUnixSockAddress(t *testing.T) { func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
require.Equal(t, testUnixAddr, l.Address()) require.Equal(t, testUnixAddr, l.Address())
} }
func TestUnixSockProtocol(t *testing.T) { func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
require.Equal(t, "unix", l.Protocol()) require.Equal(t, "unix", l.Protocol())
} }
func TestUnixSockInit(t *testing.T) { func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
err := l.Init(logger) err := l.Init(logger)
l.Close(MockCloser) l.Close(MockCloser)
require.NoError(t, err) require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr) t2Config := unixConfig
t2Config.ID = "t2"
l2 := NewUnixSock(t2Config)
err = l2.Init(logger) err = l2.Init(logger)
l2.Close(MockCloser) l2.Close(MockCloser)
require.NoError(t, err) require.NoError(t, err)
} }
func TestUnixSockServeAndClose(t *testing.T) { func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -74,7 +80,7 @@ func TestUnixSockServeAndClose(t *testing.T) {
} }
func TestUnixSockEstablishThenEnd(t *testing.T) { func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr) l := NewUnixSock(unixConfig)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -19,6 +19,8 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
const TypeWS = "ws"
var ( var (
// ErrInvalidMessage indicates that a message payload was not valid. // ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary") ErrInvalidMessage = errors.New("message type not binary")
@@ -29,7 +31,7 @@ type Websocket struct { // [MQTT-4.2.0-1]
sync.RWMutex sync.RWMutex
id string // the internal id of the listener id string // the internal id of the listener
address string // the network address to bind to address string // the network address to bind to
config *Config // configuration values for the listener config Config // configuration values for the listener
listen *http.Server // a http server for serving websocket connections listen *http.Server // a http server for serving websocket connections
log *slog.Logger // server logger log *slog.Logger // server logger
establish EstablishFn // the server's establish connection handler establish EstablishFn // the server's establish connection handler
@@ -37,15 +39,11 @@ type Websocket struct { // [MQTT-4.2.0-1]
end uint32 // ensure the close methods are only called once end uint32 // ensure the close methods are only called once
} }
// NewWebsocket initialises and returns a new Websocket listener, listening on an address. // NewWebsocket initializes and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string, config *Config) *Websocket { func NewWebsocket(config Config) *Websocket {
if config == nil {
config = new(Config)
}
return &Websocket{ return &Websocket{
id: id, id: config.ID,
address: address, address: config.Address,
config: config, config: config,
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
Subprotocols: []string{"mqtt"}, Subprotocols: []string{"mqtt"},

View File

@@ -17,35 +17,33 @@ import (
) )
func TestNewWebsocket(t *testing.T) { func TestNewWebsocket(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Equal(t, "t1", l.id) require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address) require.Equal(t, testAddr, l.address)
} }
func TestWebsocketID(t *testing.T) { func TestWebsocketID(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Equal(t, "t1", l.ID()) require.Equal(t, "t1", l.ID())
} }
func TestWebsocketAddress(t *testing.T) { func TestWebsocketAddress(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Equal(t, testAddr, l.Address()) require.Equal(t, testAddr, l.Address())
} }
func TestWebsocketProtocol(t *testing.T) { func TestWebsocketProtocol(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Equal(t, "ws", l.Protocol()) require.Equal(t, "ws", l.Protocol())
} }
func TestWebsocketProtocolTLS(t *testing.T) { func TestWebsocketProtocolTLS(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{ l := NewWebsocket(tlsConfig)
TLSConfig: tlsConfigBasic,
})
require.Equal(t, "wss", l.Protocol()) require.Equal(t, "wss", l.Protocol())
} }
func TestWebsocketInit(t *testing.T) { func TestWebsocketInit(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
require.Nil(t, l.listen) require.Nil(t, l.listen)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -53,7 +51,7 @@ func TestWebsocketInit(t *testing.T) {
} }
func TestWebsocketServeAndClose(t *testing.T) { func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
_ = l.Init(logger) _ = l.Init(logger)
o := make(chan bool) o := make(chan bool)
@@ -74,9 +72,7 @@ func TestWebsocketServeAndClose(t *testing.T) {
} }
func TestWebsocketServeTLSAndClose(t *testing.T) { func TestWebsocketServeTLSAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{ l := NewWebsocket(tlsConfig)
TLSConfig: tlsConfigBasic,
})
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -96,9 +92,9 @@ func TestWebsocketServeTLSAndClose(t *testing.T) {
} }
func TestWebsocketFailedToServe(t *testing.T) { func TestWebsocketFailedToServe(t *testing.T) {
l := NewWebsocket("t1", "wrong_addr", &Config{ config := tlsConfig
TLSConfig: tlsConfigBasic, config.Address = "wrong_addr"
}) l := NewWebsocket(config)
err := l.Init(logger) err := l.Init(logger)
require.NoError(t, err) require.NoError(t, err)
@@ -117,7 +113,7 @@ func TestWebsocketFailedToServe(t *testing.T) {
} }
func TestWebsocketUpgrade(t *testing.T) { func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
_ = l.Init(logger) _ = l.Init(logger)
e := make(chan bool) e := make(chan bool)
@@ -136,7 +132,7 @@ func TestWebsocketUpgrade(t *testing.T) {
} }
func TestWebsocketConnectionReads(t *testing.T) { func TestWebsocketConnectionReads(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil) l := NewWebsocket(basicConfig)
_ = l.Init(nil) _ = l.Init(nil)
recv := make(chan []byte) recv := make(chan []byte)

105
server.go
View File

@@ -14,6 +14,7 @@ import (
"runtime" "runtime"
"sort" "sort"
"strconv" "strconv"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -52,65 +53,69 @@ var (
ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists ErrListenerIDExists = errors.New("listener id already exists") // a listener with the same id already exists
ErrConnectionClosed = errors.New("connection not open") // connection is closed ErrConnectionClosed = errors.New("connection not open") // connection is closed
ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default ErrInlineClientNotEnabled = errors.New("please set Options.InlineClient=true to use this feature") // inline client is not enabled by default
ErrOptionsUnreadable = errors.New("unable to read options from bytes")
) )
// Capabilities indicates the capabilities and features provided by the server. // Capabilities indicates the capabilities and features provided by the server.
type Capabilities struct { type Capabilities struct {
MaximumMessageExpiryInterval int64 MaximumMessageExpiryInterval int64 `yaml:"maximum_message_expiry_interval" json:"maximum_message_expiry_interval"`
MaximumClientWritesPending int32 MaximumClientWritesPending int32 `yaml:"maximum_client_writes_pending" json:"maximum_client_writes_pending"`
MaximumSessionExpiryInterval uint32 MaximumSessionExpiryInterval uint32 `yaml:"maximum_session_expiry_interval" json:"maximum_session_expiry_interval"`
MaximumPacketSize uint32 MaximumPacketSize uint32 `yaml:"maximum_packet_size" json:"maximum_packet_size"`
maximumPacketID uint32 // unexported, used for testing only maximumPacketID uint32 // unexported, used for testing only
ReceiveMaximum uint16 ReceiveMaximum uint16 `yaml:"receive_maximum" json:"receive_maximum"`
TopicAliasMaximum uint16 TopicAliasMaximum uint16 `yaml:"topic_alias_maximum" json:"topic_alias_maximum"`
SharedSubAvailable byte SharedSubAvailable byte `yaml:"shared_sub_available" json:"shared_sub_available"`
MinimumProtocolVersion byte MinimumProtocolVersion byte `yaml:"minimum_protocol_version" json:"minimum_protocol_version"`
Compatibilities Compatibilities Compatibilities Compatibilities `yaml:"compatibilities" json:"compatibilities"`
MaximumQos byte MaximumQos byte `yaml:"maximum_qos" json:"maximum_qos"`
RetainAvailable byte RetainAvailable byte `yaml:"retain_available" json:"retain_available"`
WildcardSubAvailable byte WildcardSubAvailable byte `yaml:"wildcard_sub_available" json:"wildcard_sub_available"`
SubIDAvailable byte SubIDAvailable byte `yaml:"sub_id_available" json:"sub_id_available"`
} }
// Compatibilities provides flags for using compatibility modes. // Compatibilities provides flags for using compatibility modes.
type Compatibilities struct { type Compatibilities struct {
ObscureNotAuthorized bool // return unspecified errors instead of not authorized ObscureNotAuthorized bool `yaml:"obscure_not_authorized" json:"obscure_not_authorized"` // return unspecified errors instead of not authorized
PassiveClientDisconnect bool // don't disconnect the client forcefully after sending disconnect packet (paho - spec violation) PassiveClientDisconnect bool `yaml:"passive_client_disconnect" json:"passive_client_disconnect"` // don't disconnect the client forcefully after sending disconnect packet (paho - spec violation)
AlwaysReturnResponseInfo bool // always return response info (useful for testing) AlwaysReturnResponseInfo bool `yaml:"always_return_response_info" json:"always_return_response_info"` // always return response info (useful for testing)
RestoreSysInfoOnRestart bool // restore system info from store as if server never stopped RestoreSysInfoOnRestart bool `yaml:"restore_sys_info_on_restart" json:"restore_sys_info_on_restart"` // restore system info from store as if server never stopped
NoInheritedPropertiesOnAck bool // don't allow inherited user properties on ack (paho - spec violation) NoInheritedPropertiesOnAck bool `yaml:"no_inherited_properties_on_ack" json:"no_inherited_properties_on_ack"` // don't allow inherited user properties on ack (paho - spec violation)
} }
// Options contains configurable options for the server. // Options contains configurable options for the server.
type Options struct { type Options struct {
Listeners []listeners.Config `yaml:"listeners" json:"listeners"`
Hooks []HookLoadConfig `yaml:"hooks" json:"hooks"`
// Capabilities defines the server features and behaviour. If you only wish to modify // Capabilities defines the server features and behaviour. If you only wish to modify
// several of these values, set them explicitly - e.g. // several of these values, set them explicitly - e.g.
// server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 // server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
Capabilities *Capabilities Capabilities *Capabilities `yaml:"capabilities" json:"capabilities"`
// ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer. // ClientNetWriteBufferSize specifies the size of the client *bufio.Writer write buffer.
ClientNetWriteBufferSize int ClientNetWriteBufferSize int `yaml:"client_net_write_buffer_size" json:"client_net_write_buffer_size"`
// ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer. // ClientNetReadBufferSize specifies the size of the client *bufio.Reader read buffer.
ClientNetReadBufferSize int ClientNetReadBufferSize int `yaml:"client_net_read_buffer_size" json:"client_net_read_buffer_size"`
// Logger specifies a custom configured implementation of zerolog to override // Logger specifies a custom configured implementation of zerolog to override
// the servers default logger configuration. If you wish to change the log level, // the servers default logger configuration. If you wish to change the log level,
// of the default logger, you can do so by setting // of the default logger, you can do so by setting:
// server := mqtt.New(nil) // server := mqtt.New(nil)
// level := new(slog.LevelVar) // level := new(slog.LevelVar)
// server.Slog = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ // server.Slog = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
// Level: level, // Level: level,
// })) // }))
// level.Set(slog.LevelDebug) // level.Set(slog.LevelDebug)
Logger *slog.Logger Logger *slog.Logger `yaml:"-" json:"-"`
// SysTopicResendInterval specifies the interval between $SYS topic updates in seconds. // SysTopicResendInterval specifies the interval between $SYS topic updates in seconds.
SysTopicResendInterval int64 SysTopicResendInterval int64 `yaml:"sys_topic_resend_interval" json:"sys_topic_resend_interval"`
// Enable Inline client to allow direct subscribing and publishing from the parent codebase, // Enable Inline client to allow direct subscribing and publishing from the parent codebase,
// with negligible performance difference (disabled by default to prevent confusion in statistics). // with negligible performance difference (disabled by default to prevent confusion in statistics).
InlineClient bool InlineClient bool `yaml:"inline_client" json:"inline_client"`
} }
// Server is an MQTT broker server. It should be created with server.New() // Server is an MQTT broker server. It should be created with server.New()
@@ -250,6 +255,15 @@ func (s *Server) AddHook(hook Hook, config any) error {
return s.hooks.Add(hook, config) return s.hooks.Add(hook, config)
} }
func (s *Server) AddHooksFromConfig(hooks []HookLoadConfig) error {
for _, h := range hooks {
if err := s.AddHook(h.Hook, h.Config); err != nil {
return err
}
}
return nil
}
// AddListener adds a new network listener to the server, for receiving incoming client connections. // AddListener adds a new network listener to the server, for receiving incoming client connections.
func (s *Server) AddListener(l listeners.Listener) error { func (s *Server) AddListener(l listeners.Listener) error {
if _, ok := s.Listeners.Get(l.ID()); ok { if _, ok := s.Listeners.Get(l.ID()); ok {
@@ -268,12 +282,53 @@ func (s *Server) AddListener(l listeners.Listener) error {
return nil return nil
} }
func (s *Server) AddListenersFromConfig(configs []listeners.Config) error {
for _, conf := range configs {
var l listeners.Listener
switch strings.ToLower(conf.Type) {
case listeners.TypeTCP:
l = listeners.NewTCP(conf)
case listeners.TypeWS:
l = listeners.NewWebsocket(conf)
case listeners.TypeUnix:
l = listeners.NewUnixSock(conf)
case listeners.TypeHealthCheck:
l = listeners.NewHTTPHealthCheck(conf)
case listeners.TypeSysInfo:
l = listeners.NewHTTPStats(conf, s.Info)
case listeners.TypeMock:
l = listeners.NewMockListener(conf.ID, conf.Address)
default:
s.Log.Error("listener type unavailable by config", "listener", conf.Type)
continue
}
if err := s.AddListener(l); err != nil {
return err
}
}
return nil
}
// Serve starts the event loops responsible for establishing client connections // Serve starts the event loops responsible for establishing client connections
// on all attached listeners, publishing the system topics, and starting all hooks. // on all attached listeners, publishing the system topics, and starting all hooks.
func (s *Server) Serve() error { func (s *Server) Serve() error {
s.Log.Info("mochi mqtt starting", "version", Version) s.Log.Info("mochi mqtt starting", "version", Version)
defer s.Log.Info("mochi mqtt server started") defer s.Log.Info("mochi mqtt server started")
if len(s.Options.Listeners) > 0 {
err := s.AddListenersFromConfig(s.Options.Listeners)
if err != nil {
return err
}
}
if len(s.Options.Hooks) > 0 {
err := s.AddHooksFromConfig(s.Options.Hooks)
if err != nil {
return err
}
}
if s.hooks.Provides( if s.hooks.Provides(
StoredClients, StoredClients,
StoredInflightMessages, StoredInflightMessages,

View File

@@ -220,6 +220,34 @@ func TestServerAddListener(t *testing.T) {
require.Equal(t, ErrListenerIDExists, err) require.Equal(t, ErrListenerIDExists, err)
} }
func TestServerAddHooksFromConfig(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Log = logger
hooks := []HookLoadConfig{
{Hook: new(modifiedHookBase)},
}
err := s.AddHooksFromConfig(hooks)
require.NoError(t, err)
}
func TestServerAddHooksFromConfigError(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Log = logger
hooks := []HookLoadConfig{
{Hook: new(modifiedHookBase), Config: map[string]interface{}{}},
}
err := s.AddHooksFromConfig(hooks)
require.Error(t, err)
}
func TestServerAddListenerInitFailure(t *testing.T) { func TestServerAddListenerInitFailure(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
@@ -232,6 +260,60 @@ func TestServerAddListenerInitFailure(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestServerAddListenersFromConfig(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Log = logger
lc := []listeners.Config{
{Type: listeners.TypeTCP, ID: "tcp", Address: ":1883"},
{Type: listeners.TypeWS, ID: "ws", Address: ":1882"},
{Type: listeners.TypeHealthCheck, ID: "health", Address: ":1881"},
{Type: listeners.TypeSysInfo, ID: "info", Address: ":1880"},
{Type: listeners.TypeUnix, ID: "unix", Address: "mochi.sock"},
{Type: listeners.TypeMock, ID: "mock", Address: "0"},
{Type: "unknown", ID: "unknown"},
}
err := s.AddListenersFromConfig(lc)
require.NoError(t, err)
require.Equal(t, 6, s.Listeners.Len())
tcp, _ := s.Listeners.Get("tcp")
require.Equal(t, "[::]:1883", tcp.Address())
ws, _ := s.Listeners.Get("ws")
require.Equal(t, ":1882", ws.Address())
health, _ := s.Listeners.Get("health")
require.Equal(t, ":1881", health.Address())
info, _ := s.Listeners.Get("info")
require.Equal(t, ":1880", info.Address())
unix, _ := s.Listeners.Get("unix")
require.Equal(t, "mochi.sock", unix.Address())
mock, _ := s.Listeners.Get("mock")
require.Equal(t, "0", mock.Address())
}
func TestServerAddListenersFromConfigError(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Log = logger
lc := []listeners.Config{
{Type: listeners.TypeTCP, ID: "tcp", Address: "x"},
}
err := s.AddListenersFromConfig(lc)
require.Error(t, err)
require.Equal(t, 0, s.Listeners.Len())
}
func TestServerServe(t *testing.T) { func TestServerServe(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()
@@ -253,6 +335,57 @@ func TestServerServe(t *testing.T) {
require.Equal(t, true, listener.(*listeners.MockListener).IsServing()) require.Equal(t, true, listener.(*listeners.MockListener).IsServing())
} }
func TestServerServeFromConfig(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Options.Listeners = []listeners.Config{
{Type: listeners.TypeMock, ID: "mock", Address: "0"},
}
s.Options.Hooks = []HookLoadConfig{
{Hook: new(modifiedHookBase)},
}
err := s.Serve()
require.NoError(t, err)
time.Sleep(time.Millisecond)
require.Equal(t, 1, s.Listeners.Len())
listener, ok := s.Listeners.Get("mock")
require.Equal(t, true, ok)
require.Equal(t, true, listener.(*listeners.MockListener).IsServing())
}
func TestServerServeFromConfigListenerError(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Options.Listeners = []listeners.Config{
{Type: listeners.TypeTCP, ID: "tcp", Address: "x"},
}
err := s.Serve()
require.Error(t, err)
}
func TestServerServeFromConfigHookError(t *testing.T) {
s := newServer()
defer s.Close()
require.NotNil(t, s)
s.Options.Hooks = []HookLoadConfig{
{Hook: new(modifiedHookBase), Config: map[string]interface{}{}},
}
err := s.Serve()
require.Error(t, err)
}
func TestServerServeReadStoreFailure(t *testing.T) { func TestServerServeReadStoreFailure(t *testing.T) {
s := newServer() s := newServer()
defer s.Close() defer s.Close()