Compare commits

..

30 Commits

Author SHA1 Message Date
mochi
a2c0749640 Update server version to 1.0.5 2022-01-24 18:46:34 +00:00
JB
37293aeecf Merge pull request #18 from mochi-co/feature/connect-disconnect-hooks
OnConnect and OnDisconnect Event Hooks
2022-01-24 18:44:39 +00:00
mochi
7a2d4db6a4 Update for OnConnect and OnDisconnect hooks 2022-01-24 18:42:09 +00:00
mochi
03d2a8bc82 Add tests for OnConnect, OnDisconnect 2022-01-24 18:29:18 +00:00
mochi
4b51e5c7d1 Add OnConnect and OnDisconnect hooks to example 2022-01-24 17:42:33 +00:00
mochi
d15ad682bf Call OnDisconnect Event if applicable 2022-01-24 17:42:19 +00:00
mochi
130ffcbb53 Add OnDisconnect Event Hook 2022-01-24 17:42:04 +00:00
mochi
33cf2f991b Add testbolt file to ignore list 2022-01-24 17:41:46 +00:00
mochi
a360ea6a6c Call OnConnect Event if applicable 2022-01-24 17:37:11 +00:00
mochi
ae3aa0d3fa Add OnConnect event hook 2022-01-24 17:36:50 +00:00
mochi
811ae0e1be Prevent locks being copied by passing non-pointer to FromClient 2022-01-24 17:36:14 +00:00
JB
51d6825430 Merge pull request #15 from ClarkQAQ/master
Fixed some bugs, wish the project better and better
2022-01-17 10:08:20 +00:00
clark
514288c53e update tcp.go maybe this will be better 2022-01-16 20:06:49 +08:00
clark
957fc0a049 fix local variable black hole 2022-01-16 18:23:45 +08:00
clark
03f94f948a update mock.go plase use range 2022-01-16 18:22:37 +08:00
clark
1bc752a2b8 fix [ST1005] strings should not be capitalized 2022-01-16 18:21:33 +08:00
clark
b9db59ba12 update websocket.go fix check origin 2022-01-16 18:20:06 +08:00
JB
c0ef58c363 Update README.md 2022-01-14 17:48:21 +00:00
JB
994adea3b4 Merge pull request #14 from mochi-co/feature/allow-clients-value
Add AllowClients Field to packets
2022-01-14 17:38:29 +00:00
mochi
fc61cc9be5 Add example for AllowClients field 2022-01-14 17:04:55 +00:00
mochi
22d7338878 Add test for AllowClients field 2022-01-14 17:04:39 +00:00
mochi
3f28515706 Remove unnecessary type declarations 2022-01-14 17:04:21 +00:00
mochi
7d73ce9caf Add setupServerClients to inherit existing server instance
previously new clients generated a new server object, so system stats were not shared. This change ensures all test clients use the same server
2022-01-14 17:04:01 +00:00
mochi
0758bc961c Add AllowClients check in publishToSubscribers
If AllowClients has been set on a packet, ensure only clients in the slice are sent the message
2022-01-14 17:02:31 +00:00
mochi
8472b9ae8a use .systemInfo instead of .system for clarity 2022-01-14 17:01:42 +00:00
mochi
530a018e80 use .systemInfo instead of .system for clarity 2022-01-14 17:01:31 +00:00
mochi
0b594afb4e Add AllowClients field to packets
AllowClients field can be specified during onMessage event to selectively deliver messages
2022-01-14 16:59:17 +00:00
mochi
9d0ea957bb Increment server version 2022-01-14 16:58:48 +00:00
mochi
8067785ac4 Add tests for InSliceString 2022-01-14 16:58:33 +00:00
mochi
6ffc8a8388 Add InSliceString function
Check if a slice of strings contains a string (until slices package available)
2022-01-14 16:58:21 +00:00
14 changed files with 440 additions and 45 deletions

3
.gitignore vendored
View File

@@ -1,2 +1,3 @@
cmd/mqtt
.DS_Store
.DS_Store
server/persistence/bolt/testbolt.db

View File

@@ -25,9 +25,10 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
- Interfaces for Client Authentication and Topic access control.
- Bolt-backed persistence and storage interfaces.
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
- Basic Event Hooks (currently `onMessage`)
- Basic Event Hooks (currently `OnMessage`, `OnConnect`, `OnDisconnect`)
#### Roadmap
- Please open an issue to request new features or event hooks.
- MQTT v5 compatibility?
#### Using the Broker
@@ -105,8 +106,30 @@ err := server.AddListener(tcp, &listeners.Config{
#### Event Hooks
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
##### OnConnect
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
```go
import "github.com/mochi-co/mqtt/server/events"
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
```
##### OnDisconnect
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
```go
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
}
```
##### OnMessage
`server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns.
`server.Events.OnMessage` is called when a Publish packet is received. The method receives the published message and information about the client who published it.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
@@ -124,7 +147,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
}
```
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
#### Direct Publishing
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.

View File

@@ -44,6 +44,16 @@ func main() {
}
}()
// Add OnConnect Event Hook
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
// Add OnDisconnect Event Hook
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
}
// Add OnMessage Event Hook
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
pkx = pk
@@ -54,6 +64,12 @@ func main() {
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
}
// Example of using AllowClients to selectively deliver/drop messages.
// Only a client with the id of `allowed-client` will received messages on the topic.
if pkx.TopicName == "a/b/restricted" {
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
}
return pkx, nil
}

View File

@@ -6,7 +6,9 @@ import (
)
type Events struct {
OnMessage // published message receieved.
OnMessage // published message receieved.
OnConnect // client connected.
OnDisconnect // client disconnected.
}
type Packet packets.Packet
@@ -17,7 +19,7 @@ type Client struct {
}
// FromClient returns an event client from a client.
func FromClient(cl clients.Client) Client {
func FromClient(cl *clients.Client) Client {
return Client{
ID: cl.ID,
Listener: cl.Listener,
@@ -34,3 +36,11 @@ func FromClient(cl clients.Client) Client {
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
type OnMessage func(Client, Packet) (Packet, error)
// OnConnect is called when a client successfully connects to the broker.
type OnConnect func(Client, Packet)
// OnDisconnect is called when a client disconnects to the broker. An error value
// is passed to the function if the client disconnected abnormally, otherwise it
// will be nil on a normal disconnect.
type OnDisconnect func(Client, error)

View File

@@ -99,7 +99,7 @@ type Client struct {
packetID uint32 // the current highest packetID.
LWT LWT // the last will and testament for the client.
State State // the operational state of the client.
system *system.Info // pointers to server system info.
systemInfo *system.Info // pointers to server system info.
}
// State tracks the state of the client.
@@ -114,11 +114,11 @@ type State struct {
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{
conn: c,
r: r,
w: w,
system: s,
keepalive: defaultKeepalive,
conn: c,
r: r,
w: w,
systemInfo: s,
keepalive: defaultKeepalive,
Inflight: Inflight{
internal: make(map[uint16]InflightMessage),
},
@@ -300,7 +300,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
// Having successfully read n bytes, commit the tail forward.
cl.r.CommitTail(n)
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
return nil
}
@@ -334,7 +334,7 @@ func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.system.MessagesRecv, 1)
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
pk.FixedHeader = *fh
if pk.FixedHeader.Remaining == 0 {
@@ -345,7 +345,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
@@ -359,7 +359,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.system.PublishRecv, 1)
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
@@ -407,7 +407,7 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
case packets.Publish:
err = pk.PublishEncode(buf)
if err == nil {
atomic.AddInt64(&cl.system.PublishSent, 1)
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
}
case packets.Puback:
err = pk.PubackEncode(buf)
@@ -443,8 +443,9 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if err != nil {
return
}
atomic.AddInt64(&cl.system.BytesSent, int64(n))
atomic.AddInt64(&cl.system.MessagesSent, 1)
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
cl.refreshDeadline(cl.keepalive)

View File

@@ -340,7 +340,7 @@ func TestClientReadFixedHeader(t *testing.T) {
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
tail, head := cl.r.GetPos()
require.Equal(t, int64(2), tail)
@@ -456,8 +456,8 @@ func TestClientReadOK(t *testing.T) {
},
})
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.MessagesRecv))
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.MessagesRecv))
}
@@ -574,7 +574,7 @@ func TestClientReadPacket(t *testing.T) {
require.Equal(t, tt.packet, pk, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishRecv))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishRecv))
}
}
}
@@ -647,10 +647,10 @@ func TestClientWritePacket(t *testing.T) {
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
cl.Stop()
require.Equal(t, int64(n), atomic.LoadInt64(&cl.system.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.MessagesSent))
require.Equal(t, int64(n), atomic.LoadInt64(&cl.systemInfo.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.MessagesSent))
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishSent))
}
}
}

View File

@@ -109,6 +109,10 @@ type Packet struct {
Topics []string
Qoss []byte
// If AllowClients set, only deliver to clients in the client allow list.
// For use with the OnMessage event hook.
AllowClients []string
ReturnCodes []byte // Suback
}

View File

@@ -0,0 +1,14 @@
package utils
// InSliceString returns true if a string exists in a slice of strings.
// This temporary and should be replaced with a function from the new
// go slices package in 1.19 when available.
// https://github.com/golang/go/issues/45955
func InSliceString(sl []string, st string) bool {
for _, v := range sl {
if st == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,18 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInSliceString(t *testing.T) {
sl := []string{"a", "b", "c"}
require.Equal(t, true, InSliceString(sl, "b"))
sl = []string{"a", "a", "a"}
require.Equal(t, true, InSliceString(sl, "a"))
sl = []string{"a", "b", "c"}
require.Equal(t, false, InSliceString(sl, "d"))
}

View File

@@ -44,11 +44,8 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
l.Lock()
l.Serving = true
l.Unlock()
for {
select {
case <-l.done:
return
}
for range l.done {
return
}
}

View File

@@ -63,10 +63,12 @@ func (l *TCP) Listen(s *system.Info) error {
var err error
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})

View File

@@ -17,12 +17,13 @@ import (
)
var (
ErrInvalidMessage = errors.New("Message type not binary")
ErrInvalidMessage = errors.New("message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool { return true },
}
)

View File

@@ -14,6 +14,7 @@ import (
"github.com/mochi-co/mqtt/server/internal/clients"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
"github.com/mochi-co/mqtt/server/internal/utils"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/persistence"
@@ -21,7 +22,7 @@ import (
)
const (
Version = "1.0.2" // the server version.
Version = "1.0.5" // the server version.
)
var (
@@ -258,6 +259,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
})
}
if s.Events.OnConnect != nil {
s.Events.OnConnect(events.FromClient(cl), events.Packet(pk))
}
err = cl.Read(s.processPacket)
if err != nil {
s.closeClient(cl, true)
@@ -269,6 +274,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
atomic.AddInt64(&s.System.ClientsConnected, -1)
atomic.AddInt64(&s.System.ClientsDisconnected, 1)
if s.Events.OnDisconnect != nil {
s.Events.OnDisconnect(events.FromClient(cl), err)
}
return err
}
@@ -412,7 +421,7 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
// if an OnMessage hook exists, potentially modify the packet.
if s.Events.OnMessage != nil {
if pkx, err := s.Events.OnMessage(events.FromClient(*cl), events.Packet(pk)); err == nil {
if pkx, err := s.Events.OnMessage(events.FromClient(cl), events.Packet(pk)); err == nil {
pk = packets.Packet(pkx)
}
}
@@ -449,6 +458,14 @@ func (s *Server) retainMessage(pk packets.Packet) {
func (s *Server) publishToSubscribers(pk packets.Packet) {
for id, qos := range s.Topics.Subscribers(pk.TopicName) {
if client, ok := s.Clients.Get(id); ok {
// If the AllowClients value is set, only deliver the packet if the subscribed
// client exists in the AllowClients value. For use with the OnMessage event hook
// in cases where you want to publish messages to clients selectively.
if pk.AllowClients != nil && !utils.InSliceString(pk.AllowClients, id) {
continue
}
out := pk.PublishCopy()
if qos > out.FixedHeader.Qos { // Inherit higher desired qos values.
out.FixedHeader.Qos = qos

View File

@@ -36,6 +36,15 @@ func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) {
return
}
func setupServerClient(s *Server) (cl *clients.Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8), s.System)
cl.ID = "mochi"
cl.AC = new(auth.Allow)
cl.Start()
return
}
func TestNew(t *testing.T) {
s := New()
require.NotNil(t, s)
@@ -220,6 +229,203 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
w.Close()
}
func TestServerEventOnConnect(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedPacket events.Packet
s.Events.OnConnect = func(cl events.Client, pk events.Packet) {
hookedClient = cl
hookedPacket = pk
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
require.Empty(t, clw.Subscriptions)
w.Close()
time.Sleep(10 * time.Millisecond)
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
require.Equal(t, events.Packet(packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Connect,
Remaining: 17,
},
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
ProtocolVersion: 4,
CleanSession: true,
Keepalive: 45,
ClientIdentifier: "mochi",
}), hookedPacket)
}
func TestServerEventOnDisconnect(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedErr error
s.Events.OnDisconnect = func(cl events.Client, err error) {
hookedClient = cl
hookedErr = err
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
require.Empty(t, clw.Subscriptions)
w.Close()
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
require.Equal(t, nil, hookedErr)
}
func TestServerEventOnDisconnectOnError(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedErr error
s.Events.OnDisconnect = func(cl events.Client, err error) {
hookedClient = cl
hookedErr = err
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{0, 0})
}()
// Receive the Connack
go func() {
_, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.Error(t, errx)
require.Equal(t, "No valid packet available; 0", errx.Error())
fmt.Println(hookedErr)
require.Equal(t, errx, hookedErr)
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
}
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
s := New()
@@ -433,6 +639,10 @@ func TestServerEstablishConnectionReadPacketErr(t *testing.T) {
require.Error(t, errx)
}
func TestServerOnDisconnectErr(t *testing.T) {
}
func TestServerWriteClient(t *testing.T) {
s, cl, r, w := setupClient()
cl.ID = "mochi"
@@ -558,7 +768,7 @@ func TestServerProcessPublishQoS1Retain(t *testing.T) {
cl1.ID = "mochi1"
s.Clients.Add(cl1)
_, cl2, r2, w2 := setupClient()
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
@@ -687,7 +897,7 @@ func TestServerProcessPublishOfflineQueuing(t *testing.T) {
s.Clients.Add(cl1)
// Start and stop the receiver client
_, cl2, _, _ := setupClient()
cl2, _, _ := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
s.Topics.Subscribe("qos0", cl2.ID, 0)
@@ -922,7 +1132,7 @@ func TestServerPublishInlineSysTopicError(t *testing.T) {
require.Equal(t, int64(0), s.System.BytesSent)
}
func TestServerProcessPublishHookOnMessage(t *testing.T) {
func TestServerEventOnMessage(t *testing.T) {
s, cl1, r1, w1 := setupClient()
s.Clients.Add(cl1)
s.Topics.Subscribe("a/b/+", cl1.ID, 0)
@@ -1071,6 +1281,86 @@ func TestServerProcessPublishHookOnMessageModifyError(t *testing.T) {
require.Equal(t, int64(14), s.System.BytesSent)
}
func TestServerProcessPublishHookOnMessageAllowClients(t *testing.T) {
s, cl1, r1, w1 := setupClient()
cl1.ID = "allowed"
s.Clients.Add(cl1)
s.Topics.Subscribe("a/b/c", cl1.ID, 0)
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "not_allowed"
s.Clients.Add(cl2)
s.Topics.Subscribe("a/b/c", cl2.ID, 0)
s.Topics.Subscribe("d/e/f", cl2.ID, 0)
s.Events.OnMessage = func(cl events.Client, pk events.Packet) (events.Packet, error) {
hookedPacket := pk
if pk.TopicName == "a/b/c" {
hookedPacket.AllowClients = []string{"allowed"}
}
return hookedPacket, nil
}
ack1 := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r1)
if err != nil {
panic(err)
}
ack1 <- buf
}()
ack2 := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r2)
if err != nil {
panic(err)
}
ack2 <- buf
}()
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
}
err := s.processPacket(cl1, pk1)
require.NoError(t, err)
pk2 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "d/e/f",
Payload: []byte("a"),
}
err = s.processPacket(cl1, pk2)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
w1.Close()
w2.Close()
require.Equal(t, []byte{
byte(packets.Publish << 4), 12,
0, 5,
'a', '/', 'b', '/', 'c',
'h', 'e', 'l', 'l', 'o',
}, <-ack1)
require.Equal(t, []byte{
byte(packets.Publish << 4), 8,
0, 5,
'd', '/', 'e', '/', 'f',
'a',
}, <-ack2)
require.Equal(t, int64(24), s.System.BytesSent)
}
func TestServerProcessPuback(t *testing.T) {
s, cl, _, _ := setupClient()
cl.Inflight.Set(11, clients.InflightMessage{Packet: packets.Packet{PacketID: 11}, Sent: 0})
@@ -1443,7 +1733,7 @@ func TestServerCloseClientLWT(t *testing.T) {
}
s.Clients.Add(cl1)
_, cl2, r2, w2 := setupClient()
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
@@ -1566,14 +1856,14 @@ func TestServerLoadSubscriptions(t *testing.T) {
s.Clients.Add(cl)
subs := []persistence.Subscription{
persistence.Subscription{
{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
QoS: 1,
T: persistence.KSubscription,
},
persistence.Subscription{
{
ID: "test:d/e/f",
Client: "test",
Filter: "d/e/f",
@@ -1623,7 +1913,7 @@ func TestServerLoadInflight(t *testing.T) {
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
{
ID: "client1_if_0",
T: persistence.KInflight,
Client: "client1",
@@ -1633,7 +1923,7 @@ func TestServerLoadInflight(t *testing.T) {
Sent: 100,
Resends: 0,
},
persistence.Message{
{
ID: "client1_if_100",
T: persistence.KInflight,
Client: "client1",
@@ -1668,7 +1958,7 @@ func TestServerLoadRetained(t *testing.T) {
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
{
ID: "client1_ret_200",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
@@ -1680,7 +1970,7 @@ func TestServerLoadRetained(t *testing.T) {
Sent: 100,
Resends: 0,
},
persistence.Message{
{
ID: "client1_ret_300",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{