mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-01 22:42:14 +08:00
Compare commits
19 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
51d6825430 | ||
![]() |
514288c53e | ||
![]() |
957fc0a049 | ||
![]() |
03f94f948a | ||
![]() |
1bc752a2b8 | ||
![]() |
b9db59ba12 | ||
![]() |
c0ef58c363 | ||
![]() |
994adea3b4 | ||
![]() |
fc61cc9be5 | ||
![]() |
22d7338878 | ||
![]() |
3f28515706 | ||
![]() |
7d73ce9caf | ||
![]() |
0758bc961c | ||
![]() |
8472b9ae8a | ||
![]() |
530a018e80 | ||
![]() |
0b594afb4e | ||
![]() |
9d0ea957bb | ||
![]() |
8067785ac4 | ||
![]() |
6ffc8a8388 |
@@ -124,6 +124,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
|
||||
}
|
||||
```
|
||||
|
||||
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
|
||||
|
||||
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
|
||||
|
||||
#### Direct Publishing
|
||||
|
@@ -54,6 +54,12 @@ func main() {
|
||||
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
|
||||
}
|
||||
|
||||
// Example of using AllowClients to selectively deliver/drop messages.
|
||||
// Only a client with the id of `allowed-client` will received messages on the topic.
|
||||
if pkx.TopicName == "a/b/restricted" {
|
||||
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
|
||||
}
|
||||
|
||||
return pkx, nil
|
||||
}
|
||||
|
||||
|
@@ -99,7 +99,7 @@ type Client struct {
|
||||
packetID uint32 // the current highest packetID.
|
||||
LWT LWT // the last will and testament for the client.
|
||||
State State // the operational state of the client.
|
||||
system *system.Info // pointers to server system info.
|
||||
systemInfo *system.Info // pointers to server system info.
|
||||
}
|
||||
|
||||
// State tracks the state of the client.
|
||||
@@ -114,11 +114,11 @@ type State struct {
|
||||
// NewClient returns a new instance of Client.
|
||||
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
|
||||
cl := &Client{
|
||||
conn: c,
|
||||
r: r,
|
||||
w: w,
|
||||
system: s,
|
||||
keepalive: defaultKeepalive,
|
||||
conn: c,
|
||||
r: r,
|
||||
w: w,
|
||||
systemInfo: s,
|
||||
keepalive: defaultKeepalive,
|
||||
Inflight: Inflight{
|
||||
internal: make(map[uint16]InflightMessage),
|
||||
},
|
||||
@@ -300,7 +300,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
|
||||
// Having successfully read n bytes, commit the tail forward.
|
||||
cl.r.CommitTail(n)
|
||||
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
|
||||
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -334,7 +334,7 @@ func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error
|
||||
|
||||
// ReadPacket reads the remaining buffer into an MQTT packet.
|
||||
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
|
||||
atomic.AddInt64(&cl.system.MessagesRecv, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
|
||||
|
||||
pk.FixedHeader = *fh
|
||||
if pk.FixedHeader.Remaining == 0 {
|
||||
@@ -345,7 +345,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
|
||||
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
|
||||
|
||||
// Decode the remaining packet values using a fresh copy of the bytes,
|
||||
// otherwise the next packet will change the data of this one.
|
||||
@@ -359,7 +359,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
case packets.Publish:
|
||||
err = pk.PublishDecode(px)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.system.PublishRecv, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackDecode(px)
|
||||
@@ -407,7 +407,7 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
case packets.Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.system.PublishSent, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
@@ -443,8 +443,9 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&cl.system.BytesSent, int64(n))
|
||||
atomic.AddInt64(&cl.system.MessagesSent, 1)
|
||||
|
||||
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
|
||||
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
|
||||
|
||||
cl.refreshDeadline(cl.keepalive)
|
||||
|
||||
|
@@ -340,7 +340,7 @@ func TestClientReadFixedHeader(t *testing.T) {
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
|
||||
|
||||
tail, head := cl.r.GetPos()
|
||||
require.Equal(t, int64(2), tail)
|
||||
@@ -456,8 +456,8 @@ func TestClientReadOK(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.system.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.MessagesRecv))
|
||||
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.MessagesRecv))
|
||||
|
||||
}
|
||||
|
||||
@@ -574,7 +574,7 @@ func TestClientReadPacket(t *testing.T) {
|
||||
|
||||
require.Equal(t, tt.packet, pk, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
|
||||
if tt.packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishRecv))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishRecv))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -647,10 +647,10 @@ func TestClientWritePacket(t *testing.T) {
|
||||
|
||||
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
|
||||
cl.Stop()
|
||||
require.Equal(t, int64(n), atomic.LoadInt64(&cl.system.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.MessagesSent))
|
||||
require.Equal(t, int64(n), atomic.LoadInt64(&cl.systemInfo.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.MessagesSent))
|
||||
if tt.packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishSent))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -109,6 +109,10 @@ type Packet struct {
|
||||
Topics []string
|
||||
Qoss []byte
|
||||
|
||||
// If AllowClients set, only deliver to clients in the client allow list.
|
||||
// For use with the OnMessage event hook.
|
||||
AllowClients []string
|
||||
|
||||
ReturnCodes []byte // Suback
|
||||
}
|
||||
|
||||
|
14
server/internal/utils/utils.go
Normal file
14
server/internal/utils/utils.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package utils
|
||||
|
||||
// InSliceString returns true if a string exists in a slice of strings.
|
||||
// This temporary and should be replaced with a function from the new
|
||||
// go slices package in 1.19 when available.
|
||||
// https://github.com/golang/go/issues/45955
|
||||
func InSliceString(sl []string, st string) bool {
|
||||
for _, v := range sl {
|
||||
if st == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
18
server/internal/utils/utils_test.go
Normal file
18
server/internal/utils/utils_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInSliceString(t *testing.T) {
|
||||
sl := []string{"a", "b", "c"}
|
||||
require.Equal(t, true, InSliceString(sl, "b"))
|
||||
|
||||
sl = []string{"a", "a", "a"}
|
||||
require.Equal(t, true, InSliceString(sl, "a"))
|
||||
|
||||
sl = []string{"a", "b", "c"}
|
||||
require.Equal(t, false, InSliceString(sl, "d"))
|
||||
}
|
@@ -44,11 +44,8 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
|
||||
l.Lock()
|
||||
l.Serving = true
|
||||
l.Unlock()
|
||||
for {
|
||||
select {
|
||||
case <-l.done:
|
||||
return
|
||||
}
|
||||
for range l.done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -63,10 +63,12 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
var err error
|
||||
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
|
@@ -17,12 +17,13 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMessage = errors.New("Message type not binary")
|
||||
ErrInvalidMessage = errors.New("message type not binary")
|
||||
|
||||
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
|
||||
// websocket compliant connection.
|
||||
wsUpgrader = &websocket.Upgrader{
|
||||
Subprotocols: []string{"mqtt"},
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
)
|
||||
|
||||
|
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mochi-co/mqtt/server/internal/clients"
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
"github.com/mochi-co/mqtt/server/internal/topics"
|
||||
"github.com/mochi-co/mqtt/server/internal/utils"
|
||||
"github.com/mochi-co/mqtt/server/listeners"
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/persistence"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "1.0.2" // the server version.
|
||||
Version = "1.0.3" // the server version.
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -449,6 +450,14 @@ func (s *Server) retainMessage(pk packets.Packet) {
|
||||
func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
for id, qos := range s.Topics.Subscribers(pk.TopicName) {
|
||||
if client, ok := s.Clients.Get(id); ok {
|
||||
|
||||
// If the AllowClients value is set, only deliver the packet if the subscribed
|
||||
// client exists in the AllowClients value. For use with the OnMessage event hook
|
||||
// in cases where you want to publish messages to clients selectively.
|
||||
if pk.AllowClients != nil && !utils.InSliceString(pk.AllowClients, id) {
|
||||
continue
|
||||
}
|
||||
|
||||
out := pk.PublishCopy()
|
||||
if qos > out.FixedHeader.Qos { // Inherit higher desired qos values.
|
||||
out.FixedHeader.Qos = qos
|
||||
|
@@ -36,6 +36,15 @@ func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
func setupServerClient(s *Server) (cl *clients.Client, r net.Conn, w net.Conn) {
|
||||
r, w = net.Pipe()
|
||||
cl = clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8), s.System)
|
||||
cl.ID = "mochi"
|
||||
cl.AC = new(auth.Allow)
|
||||
cl.Start()
|
||||
return
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
s := New()
|
||||
require.NotNil(t, s)
|
||||
@@ -558,7 +567,7 @@ func TestServerProcessPublishQoS1Retain(t *testing.T) {
|
||||
cl1.ID = "mochi1"
|
||||
s.Clients.Add(cl1)
|
||||
|
||||
_, cl2, r2, w2 := setupClient()
|
||||
cl2, r2, w2 := setupServerClient(s)
|
||||
cl2.ID = "mochi2"
|
||||
s.Clients.Add(cl2)
|
||||
|
||||
@@ -687,7 +696,7 @@ func TestServerProcessPublishOfflineQueuing(t *testing.T) {
|
||||
s.Clients.Add(cl1)
|
||||
|
||||
// Start and stop the receiver client
|
||||
_, cl2, _, _ := setupClient()
|
||||
cl2, _, _ := setupServerClient(s)
|
||||
cl2.ID = "mochi2"
|
||||
s.Clients.Add(cl2)
|
||||
s.Topics.Subscribe("qos0", cl2.ID, 0)
|
||||
@@ -1071,6 +1080,86 @@ func TestServerProcessPublishHookOnMessageModifyError(t *testing.T) {
|
||||
require.Equal(t, int64(14), s.System.BytesSent)
|
||||
}
|
||||
|
||||
func TestServerProcessPublishHookOnMessageAllowClients(t *testing.T) {
|
||||
s, cl1, r1, w1 := setupClient()
|
||||
cl1.ID = "allowed"
|
||||
s.Clients.Add(cl1)
|
||||
s.Topics.Subscribe("a/b/c", cl1.ID, 0)
|
||||
|
||||
cl2, r2, w2 := setupServerClient(s)
|
||||
cl2.ID = "not_allowed"
|
||||
s.Clients.Add(cl2)
|
||||
s.Topics.Subscribe("a/b/c", cl2.ID, 0)
|
||||
s.Topics.Subscribe("d/e/f", cl2.ID, 0)
|
||||
|
||||
s.Events.OnMessage = func(cl events.Client, pk events.Packet) (events.Packet, error) {
|
||||
hookedPacket := pk
|
||||
if pk.TopicName == "a/b/c" {
|
||||
hookedPacket.AllowClients = []string{"allowed"}
|
||||
}
|
||||
return hookedPacket, nil
|
||||
}
|
||||
|
||||
ack1 := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(r1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ack1 <- buf
|
||||
}()
|
||||
|
||||
ack2 := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(r2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ack2 <- buf
|
||||
}()
|
||||
|
||||
pk1 := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello"),
|
||||
}
|
||||
err := s.processPacket(cl1, pk1)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk2 := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "d/e/f",
|
||||
Payload: []byte("a"),
|
||||
}
|
||||
err = s.processPacket(cl1, pk2)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
w1.Close()
|
||||
w2.Close()
|
||||
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Publish << 4), 12,
|
||||
0, 5,
|
||||
'a', '/', 'b', '/', 'c',
|
||||
'h', 'e', 'l', 'l', 'o',
|
||||
}, <-ack1)
|
||||
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Publish << 4), 8,
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'a',
|
||||
}, <-ack2)
|
||||
|
||||
require.Equal(t, int64(24), s.System.BytesSent)
|
||||
}
|
||||
|
||||
func TestServerProcessPuback(t *testing.T) {
|
||||
s, cl, _, _ := setupClient()
|
||||
cl.Inflight.Set(11, clients.InflightMessage{Packet: packets.Packet{PacketID: 11}, Sent: 0})
|
||||
@@ -1443,7 +1532,7 @@ func TestServerCloseClientLWT(t *testing.T) {
|
||||
}
|
||||
s.Clients.Add(cl1)
|
||||
|
||||
_, cl2, r2, w2 := setupClient()
|
||||
cl2, r2, w2 := setupServerClient(s)
|
||||
cl2.ID = "mochi2"
|
||||
s.Clients.Add(cl2)
|
||||
|
||||
@@ -1566,14 +1655,14 @@ func TestServerLoadSubscriptions(t *testing.T) {
|
||||
s.Clients.Add(cl)
|
||||
|
||||
subs := []persistence.Subscription{
|
||||
persistence.Subscription{
|
||||
{
|
||||
ID: "test:a/b/c",
|
||||
Client: "test",
|
||||
Filter: "a/b/c",
|
||||
QoS: 1,
|
||||
T: persistence.KSubscription,
|
||||
},
|
||||
persistence.Subscription{
|
||||
{
|
||||
ID: "test:d/e/f",
|
||||
Client: "test",
|
||||
Filter: "d/e/f",
|
||||
@@ -1623,7 +1712,7 @@ func TestServerLoadInflight(t *testing.T) {
|
||||
require.NotNil(t, s)
|
||||
|
||||
msgs := []persistence.Message{
|
||||
persistence.Message{
|
||||
{
|
||||
ID: "client1_if_0",
|
||||
T: persistence.KInflight,
|
||||
Client: "client1",
|
||||
@@ -1633,7 +1722,7 @@ func TestServerLoadInflight(t *testing.T) {
|
||||
Sent: 100,
|
||||
Resends: 0,
|
||||
},
|
||||
persistence.Message{
|
||||
{
|
||||
ID: "client1_if_100",
|
||||
T: persistence.KInflight,
|
||||
Client: "client1",
|
||||
@@ -1668,7 +1757,7 @@ func TestServerLoadRetained(t *testing.T) {
|
||||
require.NotNil(t, s)
|
||||
|
||||
msgs := []persistence.Message{
|
||||
persistence.Message{
|
||||
{
|
||||
ID: "client1_ret_200",
|
||||
T: persistence.KRetained,
|
||||
FixedHeader: persistence.FixedHeader{
|
||||
@@ -1680,7 +1769,7 @@ func TestServerLoadRetained(t *testing.T) {
|
||||
Sent: 100,
|
||||
Resends: 0,
|
||||
},
|
||||
persistence.Message{
|
||||
{
|
||||
ID: "client1_ret_300",
|
||||
T: persistence.KRetained,
|
||||
FixedHeader: persistence.FixedHeader{
|
||||
|
Reference in New Issue
Block a user