mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-05 08:07:06 +08:00
Compare commits
13 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
1ae050939a | ||
![]() |
f4683d27d0 | ||
![]() |
dff2b1db30 | ||
![]() |
9de6b4e427 | ||
![]() |
78c1914270 | ||
![]() |
f71bf5c3d6 | ||
![]() |
53c4a6b09f | ||
![]() |
a02c6bd8df | ||
![]() |
d8f6d63cc8 | ||
![]() |
bef13eec20 | ||
![]() |
27f3c484ad | ||
![]() |
9b5cdb0bcc | ||
![]() |
2b60a11d4a |
25
README.md
25
README.md
@@ -30,7 +30,7 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
|
||||
- Interfaces for Client Authentication and Topic access control.
|
||||
- Bolt persistence and storage interfaces (see examples folder).
|
||||
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
|
||||
- Basic Event Hooks (`OnMessage`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
|
||||
- Basic Event Hooks (`OnMessage`, `onSubscribe`, `onUnsubscribe`, `OnConnect`, `OnDisconnect`, `onProcessMessage`, `OnError`, `OnStorage`).
|
||||
- ARM32 Compatible.
|
||||
|
||||
#### Roadmap
|
||||
@@ -141,7 +141,25 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
|
||||
|
||||
```go
|
||||
server.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
|
||||
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnSubscribe
|
||||
`server.Events.OnSubscribe` is called when a client subscribes to a new topic filter.
|
||||
|
||||
```go
|
||||
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
|
||||
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnUnsubscribe
|
||||
`server.Events.OnUnsubscribe` is called when a client unsubscribes from a topic filter.
|
||||
|
||||
```go
|
||||
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
|
||||
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
|
||||
}
|
||||
```
|
||||
|
||||
@@ -158,7 +176,6 @@ If an error is returned, the packet will not be modified. and the existing packe
|
||||
|
||||
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
|
||||
|
||||
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/server/events"
|
||||
|
||||
@@ -175,6 +192,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
|
||||
|
||||
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
|
||||
|
||||
|
||||
|
||||
##### OnError
|
||||
`server.Events.OnError` is called when an error is encountered on the server, particularly within the use of a client connection status.
|
||||
|
||||
|
@@ -54,6 +54,16 @@ func main() {
|
||||
fmt.Printf("<< OnDisconnect client disconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
|
||||
// Add OnSubscribe Event Hook
|
||||
server.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
|
||||
fmt.Printf("<< OnSubscribe client subscribed %s: %s %v\n", cl.ID, filter, qos)
|
||||
}
|
||||
|
||||
// Add OnUnsubscribe Event Hook
|
||||
server.Events.OnUnsubscribe = func(filter string, cl events.Client) {
|
||||
fmt.Printf("<< OnUnsubscribe client unsubscribed %s: %s\n", cl.ID, filter)
|
||||
}
|
||||
|
||||
// Add OnMessage Event Hook
|
||||
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
|
||||
pkx = pk
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -57,14 +58,30 @@ func main() {
|
||||
|
||||
fmt.Println(aurora.Magenta("Mochi MQTT Server initializing..."), aurora.Cyan("TLS/SSL"))
|
||||
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
// Optionally, if you want clients to authenticate only with certs issued by your CA,
|
||||
// you might want to use something like this:
|
||||
// certPool := x509.NewCertPool()
|
||||
// _ = certPool.AppendCertsFromPEM(caCertPem)
|
||||
// tlsConfig := &tls.Config{
|
||||
// ClientCAs: certPool,
|
||||
// ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// }
|
||||
|
||||
server := mqtt.NewServer(nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883")
|
||||
err := server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
err = server.AddListener(tcp, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -72,11 +89,8 @@ func main() {
|
||||
|
||||
ws := listeners.NewWebsocket("ws1", ":1882")
|
||||
err = server.AddListener(ws, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@@ -84,11 +98,8 @@ func main() {
|
||||
|
||||
stats := listeners.NewHTTPStats("stats", ":8080")
|
||||
err = server.AddListener(stats, &listeners.Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLS: &listeners.TLS{
|
||||
Certificate: testCertificate,
|
||||
PrivateKey: testPrivateKey,
|
||||
},
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@@ -11,6 +11,8 @@ type Events struct {
|
||||
OnError // server error.
|
||||
OnConnect // client connected.
|
||||
OnDisconnect // client disconnected.
|
||||
OnSubscribe // topic subscription created.
|
||||
OnUnsubscribe // topic subscription removed.
|
||||
}
|
||||
|
||||
// Packets is an alias for packets.Packet.
|
||||
@@ -18,9 +20,11 @@ type Packet packets.Packet
|
||||
|
||||
// Client contains limited information about a connected client.
|
||||
type Client struct {
|
||||
ID string
|
||||
Remote string
|
||||
Listener string
|
||||
ID string
|
||||
Remote string
|
||||
Listener string
|
||||
Username []byte
|
||||
CleanSession bool
|
||||
}
|
||||
|
||||
// Clientlike is an interface for Clients and client-like objects that
|
||||
@@ -40,7 +44,7 @@ type Clientlike interface {
|
||||
// be dispatched as if the event hook had not been triggered.
|
||||
// This function will block message dispatching until it returns. To minimise this,
|
||||
// have the function open a new goroutine on the embedding side.
|
||||
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any futher
|
||||
// The `mqtt.ErrRejectPacket` error can be returned to reject and abandon any further
|
||||
// processing of the packet.
|
||||
type OnProcessMessage func(Client, Packet) (Packet, error)
|
||||
|
||||
@@ -66,3 +70,9 @@ type OnDisconnect func(Client, error)
|
||||
// OnError is called when errors that will not be passed to
|
||||
// OnDisconnect are handled by the server.
|
||||
type OnError func(Client, error)
|
||||
|
||||
// OnSubscribe is called when a new subscription filter for a client is created.
|
||||
type OnSubscribe func(filter string, cl Client, qos byte)
|
||||
|
||||
// OnUnsubscribe is called when an existing subscription filter for a client is removed.
|
||||
type OnUnsubscribe func(filter string, cl Client)
|
||||
|
@@ -200,9 +200,11 @@ func (cl *Client) Info() events.Client {
|
||||
addr = cl.conn.RemoteAddr().String()
|
||||
}
|
||||
return events.Client{
|
||||
ID: cl.ID,
|
||||
Remote: addr,
|
||||
Listener: cl.Listener,
|
||||
ID: cl.ID,
|
||||
Remote: addr,
|
||||
Username: cl.Username,
|
||||
CleanSession: cl.CleanSession,
|
||||
Listener: cl.Listener,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -21,7 +21,7 @@ func BenchmarkBytesToString(b *testing.B) {
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
name string
|
||||
rawBytes []byte
|
||||
result string
|
||||
offset int
|
||||
@@ -88,8 +88,8 @@ func TestDecodeString(t *testing.T) {
|
||||
shouldFail: ErrOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrOffsetStrInvalidUTF8,
|
||||
},
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func TestDecodeString(t *testing.T) {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
@@ -209,7 +209,7 @@ func TestDecodeByte(t *testing.T) {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
@@ -250,12 +250,11 @@ func TestDecodeUint16(t *testing.T) {
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
@@ -295,7 +294,7 @@ func TestDecodeByteBool(t *testing.T) {
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
|
@@ -215,7 +215,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
}
|
||||
|
||||
if pk.PasswordFlag {
|
||||
pk.Password, offset, err = decodeBytes(buf, offset)
|
||||
pk.Password, _, err = decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedPassword)
|
||||
}
|
||||
@@ -287,7 +287,7 @@ func (pk *Packet) ConnackDecode(buf []byte) error {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedSessionPresent)
|
||||
}
|
||||
|
||||
pk.ReturnCode, offset, err = decodeByte(buf, offset)
|
||||
pk.ReturnCode, _, err = decodeByte(buf, offset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", err, ErrMalformedReturnCode)
|
||||
}
|
||||
|
@@ -70,6 +70,9 @@ func (l *HTTPStats) Listen(s *system.Info) error {
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
@@ -79,6 +82,8 @@ func (l *HTTPStats) Listen(s *system.Info) error {
|
||||
l.listen.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
} else {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@@ -73,6 +73,18 @@ func TestHTTPStatsListen(t *testing.T) {
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLSConfig(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(new(system.Info))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestHTTPStatsListenTLS(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
@@ -10,11 +11,25 @@ import (
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
type Config struct {
|
||||
Auth auth.Controller // an authentication controller containing auth and ACL logic.
|
||||
TLS *TLS // the TLS certficates and settings for the connection.
|
||||
// Auth controller containing auth and ACL logic for
|
||||
// allowing or denying access to the server and topics.
|
||||
Auth auth.Controller
|
||||
|
||||
// TLS certficates and settings for the connection.
|
||||
//
|
||||
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
|
||||
// Please use TLSConfig instead.
|
||||
TLS *TLS
|
||||
|
||||
// TLSConfig is a tls.Config configuration to be used with the listener.
|
||||
// See examples folder for basic and mutual-tls use.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// TLS contains the TLS certificates and settings for the listener connection.
|
||||
//
|
||||
// Deprecated: Prefer exposing the tls.Config directly for greater flexibility.
|
||||
// Please use TLSConfig instead.
|
||||
type TLS struct {
|
||||
Certificate []byte // the body of a public certificate.
|
||||
PrivateKey []byte // the body of a private key.
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -37,8 +39,22 @@ WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
tlsConfigBasic *tls.Config
|
||||
)
|
||||
|
||||
func init() {
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfigBasic = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
l := New(nil)
|
||||
require.NotNil(t, l.internal)
|
||||
|
@@ -62,6 +62,9 @@ func (l *TCP) ID() string {
|
||||
func (l *TCP) Listen(s *system.Info) error {
|
||||
var err error
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
@@ -72,9 +75,12 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
} else if l.config.TLSConfig != nil {
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, l.config.TLSConfig)
|
||||
} else {
|
||||
l.listen, err = net.Listen(l.protocol, l.address)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -73,6 +73,17 @@ func TestTCPListen(t *testing.T) {
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLSConfig(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestTCPListenTLS(t *testing.T) {
|
||||
l := NewTCP("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
@@ -122,6 +122,9 @@ func (l *Websocket) Listen(s *system.Info) error {
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// The following logic is deprecated in favour of passing through the tls.Config
|
||||
// value directly, however it remains in order to provide backwards compatibility.
|
||||
// It will be removed someday, so use the preferred method (l.config.TLSConfig).
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
@@ -131,6 +134,8 @@ func (l *Websocket) Listen(s *system.Info) error {
|
||||
l.listen.TLSConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
} else {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@@ -77,6 +77,18 @@ func TestWebsocketListen(t *testing.T) {
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLSConfig(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
Auth: new(auth.Allow),
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Listen(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen.TLSConfig)
|
||||
l.listen.Close()
|
||||
}
|
||||
|
||||
func TestWebsocketListenTLS(t *testing.T) {
|
||||
l := NewWebsocket("t1", testPort)
|
||||
l.SetConfig(&Config{
|
||||
|
@@ -400,6 +400,9 @@ func (s *Server) unsubscribeClient(cl *clients.Client) {
|
||||
for k := range cl.Subscriptions {
|
||||
delete(cl.Subscriptions, k)
|
||||
if s.Topics.Unsubscribe(k, cl.ID) {
|
||||
if s.Events.OnUnsubscribe != nil {
|
||||
s.Events.OnUnsubscribe(k, cl.Info())
|
||||
}
|
||||
atomic.AddInt64(&s.System.Subscriptions, -1)
|
||||
}
|
||||
}
|
||||
@@ -735,8 +738,11 @@ func (s *Server) processSubscribe(cl *clients.Client, pk packets.Packet) error {
|
||||
if !cl.AC.ACL(cl.Username, pk.Topics[i], false) {
|
||||
retCodes[i] = packets.ErrSubAckNetworkError
|
||||
} else {
|
||||
q := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
|
||||
if q {
|
||||
r := s.Topics.Subscribe(pk.Topics[i], cl.ID, pk.Qoss[i])
|
||||
if r {
|
||||
if s.Events.OnSubscribe != nil {
|
||||
s.Events.OnSubscribe(pk.Topics[i], cl.Info(), pk.Qoss[i])
|
||||
}
|
||||
atomic.AddInt64(&s.System.Subscriptions, 1)
|
||||
}
|
||||
cl.NoteSubscription(pk.Topics[i], pk.Qoss[i])
|
||||
@@ -785,6 +791,9 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) error
|
||||
for i := 0; i < len(pk.Topics); i++ {
|
||||
q := s.Topics.Unsubscribe(pk.Topics[i], cl.ID)
|
||||
if q {
|
||||
if s.Events.OnUnsubscribe != nil {
|
||||
s.Events.OnUnsubscribe(pk.Topics[i], cl.Info())
|
||||
}
|
||||
atomic.AddInt64(&s.System.Subscriptions, -1)
|
||||
}
|
||||
cl.ForgetSubscription(pk.Topics[i])
|
||||
@@ -1004,9 +1013,13 @@ func (s *Server) loadServerInfo(v persistence.ServerInfo) {
|
||||
// loadSubscriptions restores subscriptions from the datastore.
|
||||
func (s *Server) loadSubscriptions(v []persistence.Subscription) {
|
||||
for _, sub := range v {
|
||||
s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS)
|
||||
if cl, ok := s.Clients.Get(sub.Client); ok {
|
||||
cl.NoteSubscription(sub.Filter, sub.QoS)
|
||||
if s.Topics.Subscribe(sub.Filter, sub.Client, sub.QoS) {
|
||||
if cl, ok := s.Clients.Get(sub.Client); ok {
|
||||
cl.NoteSubscription(sub.Filter, sub.QoS)
|
||||
if s.Events.OnSubscribe != nil {
|
||||
s.Events.OnSubscribe(sub.Filter, cl.Info(), sub.QoS)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -368,9 +368,10 @@ func TestServerEventOnConnect(t *testing.T) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
require.Equal(t, events.Packet(packets.Packet{
|
||||
@@ -443,9 +444,10 @@ func TestServerEventOnDisconnect(t *testing.T) {
|
||||
w.Close()
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
require.ErrorIs(t, ErrClientDisconnect, hook.err)
|
||||
@@ -506,9 +508,10 @@ func TestServerEventOnDisconnectOnError(t *testing.T) {
|
||||
require.Equal(t, errx, hook.err)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
ID: "mochi",
|
||||
Remote: "pipe",
|
||||
Listener: "tcp",
|
||||
CleanSession: true,
|
||||
}, hook.client)
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
@@ -1946,6 +1949,15 @@ func TestServerProcessSubscribeInvalid(t *testing.T) {
|
||||
func TestServerProcessSubscribe(t *testing.T) {
|
||||
s, cl, r, w := setupClient()
|
||||
|
||||
subscribeEvent := ""
|
||||
subscribeClient := ""
|
||||
s.Events.OnSubscribe = func(filter string, cl events.Client, qos byte) {
|
||||
if filter == "a/b/c" {
|
||||
subscribeEvent = "a/b/c"
|
||||
subscribeClient = cl.ID
|
||||
}
|
||||
}
|
||||
|
||||
s.Topics.RetainMessage(packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
@@ -1995,6 +2007,8 @@ func TestServerProcessSubscribe(t *testing.T) {
|
||||
require.Equal(t, byte(1), cl.Subscriptions["d/e/f"])
|
||||
require.Equal(t, topics.Subscriptions{cl.ID: 0}, s.Topics.Subscribers("a/b/c"))
|
||||
require.Equal(t, topics.Subscriptions{cl.ID: 1}, s.Topics.Subscribers("d/e/f"))
|
||||
require.Equal(t, "a/b/c", subscribeEvent)
|
||||
require.Equal(t, cl.ID, subscribeClient)
|
||||
}
|
||||
|
||||
func TestServerProcessSubscribeFailACL(t *testing.T) {
|
||||
@@ -2114,6 +2128,16 @@ func TestServerProcessUnsubscribeInvalid(t *testing.T) {
|
||||
|
||||
func TestServerProcessUnsubscribe(t *testing.T) {
|
||||
s, cl, r, w := setupClient()
|
||||
|
||||
unsubscribeEvent := ""
|
||||
unsubscribeClient := ""
|
||||
s.Events.OnUnsubscribe = func(filter string, cl events.Client) {
|
||||
if filter == "a/b/c" {
|
||||
unsubscribeEvent = "a/b/c"
|
||||
unsubscribeClient = cl.ID
|
||||
}
|
||||
}
|
||||
|
||||
s.Clients.Add(cl)
|
||||
s.Topics.Subscribe("a/b/c", cl.ID, 0)
|
||||
s.Topics.Subscribe("d/e/f", cl.ID, 1)
|
||||
@@ -2155,6 +2179,9 @@ func TestServerProcessUnsubscribe(t *testing.T) {
|
||||
|
||||
require.NotEmpty(t, s.Topics.Subscribers("a/b/+"))
|
||||
require.Contains(t, cl.Subscriptions, "a/b/+")
|
||||
|
||||
require.Equal(t, "a/b/c", unsubscribeEvent)
|
||||
require.Equal(t, cl.ID, unsubscribeClient)
|
||||
}
|
||||
|
||||
func TestServerProcessUnsubscribeWriteError(t *testing.T) {
|
||||
|
Reference in New Issue
Block a user