mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-04 15:52:55 +08:00
Compare commits
37 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
86e0a5827e | ||
![]() |
06c399b606 | ||
![]() |
ed117f67a1 | ||
![]() |
880a3299e1 | ||
![]() |
1c408d05be | ||
![]() |
fce495f83e | ||
![]() |
471ca00a64 | ||
![]() |
a2c0749640 | ||
![]() |
37293aeecf | ||
![]() |
7a2d4db6a4 | ||
![]() |
03d2a8bc82 | ||
![]() |
4b51e5c7d1 | ||
![]() |
d15ad682bf | ||
![]() |
130ffcbb53 | ||
![]() |
33cf2f991b | ||
![]() |
a360ea6a6c | ||
![]() |
ae3aa0d3fa | ||
![]() |
811ae0e1be | ||
![]() |
51d6825430 | ||
![]() |
514288c53e | ||
![]() |
957fc0a049 | ||
![]() |
03f94f948a | ||
![]() |
1bc752a2b8 | ||
![]() |
b9db59ba12 | ||
![]() |
c0ef58c363 | ||
![]() |
994adea3b4 | ||
![]() |
fc61cc9be5 | ||
![]() |
22d7338878 | ||
![]() |
3f28515706 | ||
![]() |
7d73ce9caf | ||
![]() |
0758bc961c | ||
![]() |
8472b9ae8a | ||
![]() |
530a018e80 | ||
![]() |
0b594afb4e | ||
![]() |
9d0ea957bb | ||
![]() |
8067785ac4 | ||
![]() |
6ffc8a8388 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
cmd/mqtt
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
server/persistence/bolt/testbolt.db
|
||||
|
31
README.md
31
README.md
@@ -25,9 +25,11 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
|
||||
- Interfaces for Client Authentication and Topic access control.
|
||||
- Bolt-backed persistence and storage interfaces.
|
||||
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
|
||||
- Basic Event Hooks (currently `onMessage`)
|
||||
- Basic Event Hooks (currently `OnMessage`, `OnConnect`, `OnDisconnect`).
|
||||
- ARM32 Compatible.
|
||||
|
||||
#### Roadmap
|
||||
- Please open an issue to request new features or event hooks.
|
||||
- MQTT v5 compatibility?
|
||||
|
||||
#### Using the Broker
|
||||
@@ -105,8 +107,30 @@ err := server.AddListener(tcp, &listeners.Config{
|
||||
#### Event Hooks
|
||||
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
|
||||
|
||||
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
|
||||
|
||||
##### OnConnect
|
||||
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
|
||||
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/server/events"
|
||||
|
||||
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
|
||||
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnDisconnect
|
||||
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
|
||||
|
||||
```go
|
||||
server.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnMessage
|
||||
`server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns.
|
||||
`server.Events.OnMessage` is called when a Publish packet is received. The method receives the published message and information about the client who published it.
|
||||
|
||||
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
|
||||
|
||||
@@ -124,7 +148,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
|
||||
}
|
||||
```
|
||||
|
||||
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
|
||||
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
|
||||
|
||||
|
||||
#### Direct Publishing
|
||||
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.
|
||||
|
@@ -44,6 +44,16 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Add OnConnect Event Hook
|
||||
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
|
||||
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
|
||||
}
|
||||
|
||||
// Add OnDisconnect Event Hook
|
||||
server.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
|
||||
// Add OnMessage Event Hook
|
||||
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
|
||||
pkx = pk
|
||||
@@ -54,6 +64,12 @@ func main() {
|
||||
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
|
||||
}
|
||||
|
||||
// Example of using AllowClients to selectively deliver/drop messages.
|
||||
// Only a client with the id of `allowed-client` will received messages on the topic.
|
||||
if pkx.TopicName == "a/b/restricted" {
|
||||
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
|
||||
}
|
||||
|
||||
return pkx, nil
|
||||
}
|
||||
|
||||
|
@@ -6,7 +6,9 @@ import (
|
||||
)
|
||||
|
||||
type Events struct {
|
||||
OnMessage // published message receieved.
|
||||
OnMessage // published message receieved.
|
||||
OnConnect // client connected.
|
||||
OnDisconnect // client disconnected.
|
||||
}
|
||||
|
||||
type Packet packets.Packet
|
||||
@@ -17,7 +19,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
// FromClient returns an event client from a client.
|
||||
func FromClient(cl clients.Client) Client {
|
||||
func FromClient(cl *clients.Client) Client {
|
||||
return Client{
|
||||
ID: cl.ID,
|
||||
Listener: cl.Listener,
|
||||
@@ -34,3 +36,11 @@ func FromClient(cl clients.Client) Client {
|
||||
// This function will block message dispatching until it returns. To minimise this,
|
||||
// have the function open a new goroutine on the embedding side.
|
||||
type OnMessage func(Client, Packet) (Packet, error)
|
||||
|
||||
// OnConnect is called when a client successfully connects to the broker.
|
||||
type OnConnect func(Client, Packet)
|
||||
|
||||
// OnDisconnect is called when a client disconnects to the broker. An error value
|
||||
// is passed to the function if the client disconnected abnormally, otherwise it
|
||||
// will be nil on a normal disconnect.
|
||||
type OnDisconnect func(Client, error)
|
||||
|
@@ -24,12 +24,13 @@ type Buffer struct {
|
||||
block int // the size of the R/W block.
|
||||
buf []byte // the bytes buffer.
|
||||
tmp []byte // a temporary buffer.
|
||||
_ int32 // align the next fields to an 8-byte boundary for atomic access.
|
||||
head int64 // the current position in the sequence - a forever increasing index.
|
||||
tail int64 // the committed position in the sequence - a forever increasing index.
|
||||
rcond *sync.Cond // the sync condition for the buffer reader.
|
||||
wcond *sync.Cond // the sync condition for the buffer writer.
|
||||
done int64 // indicates that the buffer is closed.
|
||||
State int64 // indicates whether the buffer is reading from (1) or writing to (2).
|
||||
done uint32 // indicates that the buffer is closed.
|
||||
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
|
||||
}
|
||||
|
||||
// NewBuffer returns a new instance of buffer. You should call NewReader or
|
||||
@@ -127,7 +128,7 @@ func (b *Buffer) awaitEmpty(n int) error {
|
||||
// then wait until tail has moved.
|
||||
b.rcond.L.Lock()
|
||||
for !b.checkEmpty(n) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
b.rcond.L.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
@@ -146,7 +147,7 @@ func (b *Buffer) awaitFilled(n int) error {
|
||||
// the forever-incrementing tail and head integers.
|
||||
b.wcond.L.Lock()
|
||||
for !b.checkFilled(n) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
b.wcond.L.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
@@ -196,7 +197,7 @@ func (b *Buffer) CapDelta() int {
|
||||
|
||||
// Stop signals the buffer to stop processing.
|
||||
func (b *Buffer) Stop() {
|
||||
atomic.StoreInt64(&b.done, 1)
|
||||
atomic.StoreUint32(&b.done, 1)
|
||||
b.rcond.L.Lock()
|
||||
b.rcond.Broadcast()
|
||||
b.rcond.L.Unlock()
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -50,6 +51,18 @@ func TestNewBufferFromSlice0Size(t *testing.T) {
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestAtomicAlignment(t *testing.T) {
|
||||
var b Buffer
|
||||
|
||||
offset := unsafe.Offsetof(b.head)
|
||||
require.Equalf(t, uintptr(0), offset%8,
|
||||
"head requires 64-bit alignment for atomic: offset %d", offset)
|
||||
|
||||
offset = unsafe.Offsetof(b.tail)
|
||||
require.Equalf(t, uintptr(0), offset%8,
|
||||
"tail requires 64-bit alignment for atomic: offset %d", offset)
|
||||
}
|
||||
|
||||
func TestGetPos(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
tail, head := buf.GetPos()
|
||||
@@ -140,7 +153,7 @@ func TestAwaitFilledEnded(t *testing.T) {
|
||||
o <- buf.awaitFilled(4)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
@@ -190,7 +203,7 @@ func TestAwaitEmptyEnded(t *testing.T) {
|
||||
o <- buf.awaitEmpty(4)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
@@ -280,7 +293,7 @@ func TestCommitTailEnded(t *testing.T) {
|
||||
o <- buf.CommitTail(5)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
@@ -300,5 +313,5 @@ func TestCapDelta(t *testing.T) {
|
||||
func TestStop(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
buf.Stop()
|
||||
require.Equal(t, int64(1), buf.done)
|
||||
require.Equal(t, uint32(1), buf.done)
|
||||
}
|
||||
|
@@ -32,10 +32,10 @@ func NewReaderFromSlice(block int, p []byte) *Reader {
|
||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
||||
// there is sufficient capacity to do so.
|
||||
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
atomic.StoreInt64(&b.State, 1)
|
||||
defer atomic.StoreInt64(&b.State, 0)
|
||||
atomic.StoreUint32(&b.State, 1)
|
||||
defer atomic.StoreUint32(&b.State, 0)
|
||||
for {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
return total, nil
|
||||
}
|
||||
|
||||
|
@@ -60,7 +60,7 @@ func TestReadFromWrap(t *testing.T) {
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
go func() {
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
@@ -116,7 +116,7 @@ func TestReadEnded(t *testing.T) {
|
||||
o <- err
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
@@ -32,10 +32,10 @@ func NewWriterFromSlice(block int, p []byte) *Writer {
|
||||
|
||||
// WriteTo writes the contents of the buffer to an io.Writer.
|
||||
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
|
||||
atomic.StoreInt64(&b.State, 2)
|
||||
defer atomic.StoreInt64(&b.State, 0)
|
||||
atomic.StoreUint32(&b.State, 2)
|
||||
defer atomic.StoreUint32(&b.State, 0)
|
||||
for {
|
||||
if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 {
|
||||
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
|
||||
return total, io.EOF
|
||||
}
|
||||
|
||||
|
@@ -59,7 +59,7 @@ func TestWriteTo(t *testing.T) {
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
@@ -74,7 +74,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
cl.RLock()
|
||||
for _, v := range cl.internal {
|
||||
if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 {
|
||||
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
|
||||
clients = append(clients, v)
|
||||
}
|
||||
}
|
||||
@@ -99,12 +99,12 @@ type Client struct {
|
||||
packetID uint32 // the current highest packetID.
|
||||
LWT LWT // the last will and testament for the client.
|
||||
State State // the operational state of the client.
|
||||
system *system.Info // pointers to server system info.
|
||||
systemInfo *system.Info // pointers to server system info.
|
||||
}
|
||||
|
||||
// State tracks the state of the client.
|
||||
type State struct {
|
||||
Done int64 // atomic counter which indicates that the client has closed.
|
||||
Done uint32 // atomic counter which indicates that the client has closed.
|
||||
started *sync.WaitGroup // tracks the goroutines which have been started.
|
||||
endedW *sync.WaitGroup // tracks when the writer has ended.
|
||||
endedR *sync.WaitGroup // tracks when the reader has ended.
|
||||
@@ -114,11 +114,11 @@ type State struct {
|
||||
// NewClient returns a new instance of Client.
|
||||
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
|
||||
cl := &Client{
|
||||
conn: c,
|
||||
r: r,
|
||||
w: w,
|
||||
system: s,
|
||||
keepalive: defaultKeepalive,
|
||||
conn: c,
|
||||
r: r,
|
||||
w: w,
|
||||
systemInfo: s,
|
||||
keepalive: defaultKeepalive,
|
||||
Inflight: Inflight{
|
||||
internal: make(map[uint16]InflightMessage),
|
||||
},
|
||||
@@ -240,7 +240,7 @@ func (cl *Client) Start() {
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop() {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ func (cl *Client) Stop() {
|
||||
cl.conn.Close()
|
||||
|
||||
cl.State.endedR.Wait()
|
||||
atomic.StoreInt64(&cl.State.Done, 1)
|
||||
atomic.StoreUint32(&cl.State.Done, 1)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
|
||||
// Having successfully read n bytes, commit the tail forward.
|
||||
cl.r.CommitTail(n)
|
||||
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
|
||||
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -309,7 +309,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
// an error is encountered (or the connection is closed).
|
||||
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
|
||||
for {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -334,7 +334,7 @@ func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error
|
||||
|
||||
// ReadPacket reads the remaining buffer into an MQTT packet.
|
||||
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
|
||||
atomic.AddInt64(&cl.system.MessagesRecv, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
|
||||
|
||||
pk.FixedHeader = *fh
|
||||
if pk.FixedHeader.Remaining == 0 {
|
||||
@@ -345,7 +345,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
|
||||
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
|
||||
|
||||
// Decode the remaining packet values using a fresh copy of the bytes,
|
||||
// otherwise the next packet will change the data of this one.
|
||||
@@ -359,7 +359,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
case packets.Publish:
|
||||
err = pk.PublishDecode(px)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.system.PublishRecv, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackDecode(px)
|
||||
@@ -391,7 +391,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 {
|
||||
return 0, ErrConnectionClosed
|
||||
}
|
||||
|
||||
@@ -407,7 +407,7 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
case packets.Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.system.PublishSent, 1)
|
||||
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
@@ -443,8 +443,9 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&cl.system.BytesSent, int64(n))
|
||||
atomic.AddInt64(&cl.system.MessagesSent, 1)
|
||||
|
||||
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
|
||||
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
|
||||
|
||||
cl.refreshDeadline(cl.keepalive)
|
||||
|
||||
|
@@ -316,8 +316,8 @@ func TestClientStart(t *testing.T) {
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
time.Sleep(time.Millisecond)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.r.State))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.w.State))
|
||||
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.r.State))
|
||||
require.Equal(t, uint32(2), atomic.LoadUint32(&cl.w.State))
|
||||
}
|
||||
|
||||
func BenchmarkClientStart(b *testing.B) {
|
||||
@@ -340,7 +340,7 @@ func TestClientReadFixedHeader(t *testing.T) {
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
|
||||
|
||||
tail, head := cl.r.GetPos()
|
||||
require.Equal(t, int64(2), tail)
|
||||
@@ -456,8 +456,8 @@ func TestClientReadOK(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.system.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.MessagesRecv))
|
||||
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.MessagesRecv))
|
||||
|
||||
}
|
||||
|
||||
@@ -574,7 +574,7 @@ func TestClientReadPacket(t *testing.T) {
|
||||
|
||||
require.Equal(t, tt.packet, pk, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
|
||||
if tt.packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishRecv))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishRecv))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -647,10 +647,10 @@ func TestClientWritePacket(t *testing.T) {
|
||||
|
||||
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
|
||||
cl.Stop()
|
||||
require.Equal(t, int64(n), atomic.LoadInt64(&cl.system.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.MessagesSent))
|
||||
require.Equal(t, int64(n), atomic.LoadInt64(&cl.systemInfo.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.MessagesSent))
|
||||
if tt.packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishSent))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -16,7 +16,7 @@ type FixedHeader struct {
|
||||
// Encode encodes the FixedHeader and returns a bytes buffer.
|
||||
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
|
||||
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
|
||||
encodeLength(buf, fh.Remaining)
|
||||
encodeLength(buf, int64(fh.Remaining))
|
||||
}
|
||||
|
||||
// decode extracts the specification bits from the header byte.
|
||||
@@ -44,7 +44,7 @@ func (fh *FixedHeader) Decode(headerByte byte) error {
|
||||
}
|
||||
|
||||
// encodeLength writes length bits for the header.
|
||||
func encodeLength(buf *bytes.Buffer, length int) {
|
||||
func encodeLength(buf *bytes.Buffer, length int64) {
|
||||
for {
|
||||
digit := byte(length % 128)
|
||||
length /= 128
|
||||
|
@@ -192,7 +192,7 @@ func BenchmarkFixedHeaderDecode(b *testing.B) {
|
||||
|
||||
func TestEncodeLength(t *testing.T) {
|
||||
tt := []struct {
|
||||
have int
|
||||
have int64
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
|
@@ -109,6 +109,10 @@ type Packet struct {
|
||||
Topics []string
|
||||
Qoss []byte
|
||||
|
||||
// If AllowClients set, only deliver to clients in the client allow list.
|
||||
// For use with the OnMessage event hook.
|
||||
AllowClients []string
|
||||
|
||||
ReturnCodes []byte // Suback
|
||||
}
|
||||
|
||||
|
14
server/internal/utils/utils.go
Normal file
14
server/internal/utils/utils.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package utils
|
||||
|
||||
// InSliceString returns true if a string exists in a slice of strings.
|
||||
// This temporary and should be replaced with a function from the new
|
||||
// go slices package in 1.19 when available.
|
||||
// https://github.com/golang/go/issues/45955
|
||||
func InSliceString(sl []string, st string) bool {
|
||||
for _, v := range sl {
|
||||
if st == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
18
server/internal/utils/utils_test.go
Normal file
18
server/internal/utils/utils_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInSliceString(t *testing.T) {
|
||||
sl := []string{"a", "b", "c"}
|
||||
require.Equal(t, true, InSliceString(sl, "b"))
|
||||
|
||||
sl = []string{"a", "a", "a"}
|
||||
require.Equal(t, true, InSliceString(sl, "a"))
|
||||
|
||||
sl = []string{"a", "b", "c"}
|
||||
require.Equal(t, false, InSliceString(sl, "d"))
|
||||
}
|
@@ -22,7 +22,7 @@ type HTTPStats struct {
|
||||
system *system.Info // pointers to the server data.
|
||||
address string // the network address to bind to.
|
||||
listen *http.Server // the http server.
|
||||
end int64 // ensure the close methods are only called once.}
|
||||
end uint32 // ensure the close methods are only called once.}
|
||||
}
|
||||
|
||||
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
|
||||
@@ -98,9 +98,7 @@ func (l *HTTPStats) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
|
@@ -44,11 +44,8 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
|
||||
l.Lock()
|
||||
l.Serving = true
|
||||
l.Unlock()
|
||||
for {
|
||||
select {
|
||||
case <-l.done:
|
||||
return
|
||||
}
|
||||
for range l.done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -18,7 +18,7 @@ type TCP struct {
|
||||
protocol string // the TCP protocol to use.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
end int64 // ensure the close methods are only called once.
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewTCP initialises and returns a new TCP listener, listening on an address.
|
||||
@@ -63,10 +63,12 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
var err error
|
||||
|
||||
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
|
||||
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
var cert tls.Certificate
|
||||
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
@@ -84,7 +86,7 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
// connection callback for any received.
|
||||
func (l *TCP) Serve(establish EstablishFunc) {
|
||||
for {
|
||||
if atomic.LoadInt64(&l.end) == 1 {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -93,7 +95,7 @@ func (l *TCP) Serve(establish EstablishFunc) {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go establish(l.id, conn, l.config.Auth)
|
||||
}
|
||||
}
|
||||
@@ -104,8 +106,7 @@ func (l *TCP) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
|
@@ -17,12 +17,13 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMessage = errors.New("Message type not binary")
|
||||
ErrInvalidMessage = errors.New("message type not binary")
|
||||
|
||||
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
|
||||
// websocket compliant connection.
|
||||
wsUpgrader = &websocket.Upgrader{
|
||||
Subprotocols: []string{"mqtt"},
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ type Websocket struct {
|
||||
config *Config // configuration values for the listener.
|
||||
address string // the network address to bind to.
|
||||
listen *http.Server // an http server for serving websocket connections.
|
||||
end int64 // ensure the close methods are only called once.
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
establish EstablishFunc // the server's establish conection handler.
|
||||
}
|
||||
|
||||
@@ -161,8 +162,7 @@ func (l *Websocket) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
|
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mochi-co/mqtt/server/internal/clients"
|
||||
"github.com/mochi-co/mqtt/server/internal/packets"
|
||||
"github.com/mochi-co/mqtt/server/internal/topics"
|
||||
"github.com/mochi-co/mqtt/server/internal/utils"
|
||||
"github.com/mochi-co/mqtt/server/listeners"
|
||||
"github.com/mochi-co/mqtt/server/listeners/auth"
|
||||
"github.com/mochi-co/mqtt/server/persistence"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "1.0.2" // the server version.
|
||||
Version = "1.1.0" // the server version.
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -207,7 +208,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
var sessionPresent bool
|
||||
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
|
||||
existing.Lock()
|
||||
if atomic.LoadInt64(&existing.State.Done) == 1 {
|
||||
if atomic.LoadUint32(&existing.State.Done) == 1 {
|
||||
atomic.AddInt64(&s.System.ClientsDisconnected, -1)
|
||||
}
|
||||
existing.Stop()
|
||||
@@ -258,6 +259,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
})
|
||||
}
|
||||
|
||||
if s.Events.OnConnect != nil {
|
||||
s.Events.OnConnect(events.FromClient(cl), events.Packet(pk))
|
||||
}
|
||||
|
||||
err = cl.Read(s.processPacket)
|
||||
if err != nil {
|
||||
s.closeClient(cl, true)
|
||||
@@ -269,6 +274,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
atomic.AddInt64(&s.System.ClientsConnected, -1)
|
||||
atomic.AddInt64(&s.System.ClientsDisconnected, 1)
|
||||
|
||||
if s.Events.OnDisconnect != nil {
|
||||
s.Events.OnDisconnect(events.FromClient(cl), err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -412,7 +421,7 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
|
||||
|
||||
// if an OnMessage hook exists, potentially modify the packet.
|
||||
if s.Events.OnMessage != nil {
|
||||
if pkx, err := s.Events.OnMessage(events.FromClient(*cl), events.Packet(pk)); err == nil {
|
||||
if pkx, err := s.Events.OnMessage(events.FromClient(cl), events.Packet(pk)); err == nil {
|
||||
pk = packets.Packet(pkx)
|
||||
}
|
||||
}
|
||||
@@ -449,6 +458,14 @@ func (s *Server) retainMessage(pk packets.Packet) {
|
||||
func (s *Server) publishToSubscribers(pk packets.Packet) {
|
||||
for id, qos := range s.Topics.Subscribers(pk.TopicName) {
|
||||
if client, ok := s.Clients.Get(id); ok {
|
||||
|
||||
// If the AllowClients value is set, only deliver the packet if the subscribed
|
||||
// client exists in the AllowClients value. For use with the OnMessage event hook
|
||||
// in cases where you want to publish messages to clients selectively.
|
||||
if pk.AllowClients != nil && !utils.InSliceString(pk.AllowClients, id) {
|
||||
continue
|
||||
}
|
||||
|
||||
out := pk.PublishCopy()
|
||||
if qos > out.FixedHeader.Qos { // Inherit higher desired qos values.
|
||||
out.FixedHeader.Qos = qos
|
||||
|
@@ -36,6 +36,15 @@ func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
func setupServerClient(s *Server) (cl *clients.Client, r net.Conn, w net.Conn) {
|
||||
r, w = net.Pipe()
|
||||
cl = clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8), s.System)
|
||||
cl.ID = "mochi"
|
||||
cl.AC = new(auth.Allow)
|
||||
cl.Start()
|
||||
return
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
s := New()
|
||||
require.NotNil(t, s)
|
||||
@@ -220,6 +229,203 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
|
||||
w.Close()
|
||||
}
|
||||
|
||||
func TestServerEventOnConnect(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
s, cl, _, _ := setupClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
var hookedClient events.Client
|
||||
var hookedPacket events.Packet
|
||||
s.Events.OnConnect = func(cl events.Client, pk events.Packet) {
|
||||
hookedClient = cl
|
||||
hookedPacket = pk
|
||||
}
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
|
||||
}()
|
||||
|
||||
go func() {
|
||||
w.Write([]byte{
|
||||
byte(packets.Connect << 4), 17, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
2, // Packet Flags - clean session
|
||||
0, 45, // Keepalive
|
||||
0, 5, // Client ID - MSB+LSB
|
||||
'm', 'o', 'c', 'h', 'i', // Client ID
|
||||
})
|
||||
w.Write([]byte{byte(packets.Disconnect << 4), 0})
|
||||
}()
|
||||
|
||||
// Receive the Connack
|
||||
recv := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(w)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
recv <- buf
|
||||
}()
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
require.Equal(t, true, ok)
|
||||
clw.Stop()
|
||||
|
||||
errx := <-o
|
||||
require.NoError(t, errx)
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Connack << 4), 2,
|
||||
0, packets.Accepted,
|
||||
}, <-recv)
|
||||
require.Empty(t, clw.Subscriptions)
|
||||
|
||||
w.Close()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Listener: "tcp",
|
||||
}, hookedClient)
|
||||
|
||||
require.Equal(t, events.Packet(packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Connect,
|
||||
Remaining: 17,
|
||||
},
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
ProtocolVersion: 4,
|
||||
CleanSession: true,
|
||||
Keepalive: 45,
|
||||
ClientIdentifier: "mochi",
|
||||
}), hookedPacket)
|
||||
|
||||
}
|
||||
|
||||
func TestServerEventOnDisconnect(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
s, cl, _, _ := setupClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
var hookedClient events.Client
|
||||
var hookedErr error
|
||||
s.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
hookedClient = cl
|
||||
hookedErr = err
|
||||
}
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
|
||||
}()
|
||||
|
||||
go func() {
|
||||
w.Write([]byte{
|
||||
byte(packets.Connect << 4), 17, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
2, // Packet Flags - clean session
|
||||
0, 45, // Keepalive
|
||||
0, 5, // Client ID - MSB+LSB
|
||||
'm', 'o', 'c', 'h', 'i', // Client ID
|
||||
})
|
||||
w.Write([]byte{byte(packets.Disconnect << 4), 0})
|
||||
}()
|
||||
|
||||
// Receive the Connack
|
||||
recv := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(w)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
recv <- buf
|
||||
}()
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
require.Equal(t, true, ok)
|
||||
clw.Stop()
|
||||
|
||||
errx := <-o
|
||||
require.NoError(t, errx)
|
||||
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Connack << 4), 2,
|
||||
0, packets.Accepted,
|
||||
}, <-recv)
|
||||
require.Empty(t, clw.Subscriptions)
|
||||
|
||||
w.Close()
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Listener: "tcp",
|
||||
}, hookedClient)
|
||||
|
||||
require.Equal(t, nil, hookedErr)
|
||||
}
|
||||
|
||||
func TestServerEventOnDisconnectOnError(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
s, cl, _, _ := setupClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
var hookedClient events.Client
|
||||
var hookedErr error
|
||||
s.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
hookedClient = cl
|
||||
hookedErr = err
|
||||
}
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
|
||||
}()
|
||||
|
||||
go func() {
|
||||
w.Write([]byte{
|
||||
byte(packets.Connect << 4), 17, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
2, // Packet Flags - clean session
|
||||
0, 45, // Keepalive
|
||||
0, 5, // Client ID - MSB+LSB
|
||||
'm', 'o', 'c', 'h', 'i', // Client ID
|
||||
})
|
||||
|
||||
w.Write([]byte{0, 0})
|
||||
}()
|
||||
|
||||
// Receive the Connack
|
||||
go func() {
|
||||
_, err := ioutil.ReadAll(w)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
require.Equal(t, true, ok)
|
||||
clw.Stop()
|
||||
|
||||
errx := <-o
|
||||
require.Error(t, errx)
|
||||
require.Equal(t, "No valid packet available; 0", errx.Error())
|
||||
fmt.Println(hookedErr)
|
||||
require.Equal(t, errx, hookedErr)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Listener: "tcp",
|
||||
}, hookedClient)
|
||||
|
||||
}
|
||||
|
||||
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
|
||||
s := New()
|
||||
|
||||
@@ -433,6 +639,10 @@ func TestServerEstablishConnectionReadPacketErr(t *testing.T) {
|
||||
require.Error(t, errx)
|
||||
}
|
||||
|
||||
func TestServerOnDisconnectErr(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestServerWriteClient(t *testing.T) {
|
||||
s, cl, r, w := setupClient()
|
||||
cl.ID = "mochi"
|
||||
@@ -558,7 +768,7 @@ func TestServerProcessPublishQoS1Retain(t *testing.T) {
|
||||
cl1.ID = "mochi1"
|
||||
s.Clients.Add(cl1)
|
||||
|
||||
_, cl2, r2, w2 := setupClient()
|
||||
cl2, r2, w2 := setupServerClient(s)
|
||||
cl2.ID = "mochi2"
|
||||
s.Clients.Add(cl2)
|
||||
|
||||
@@ -687,7 +897,7 @@ func TestServerProcessPublishOfflineQueuing(t *testing.T) {
|
||||
s.Clients.Add(cl1)
|
||||
|
||||
// Start and stop the receiver client
|
||||
_, cl2, _, _ := setupClient()
|
||||
cl2, _, _ := setupServerClient(s)
|
||||
cl2.ID = "mochi2"
|
||||
s.Clients.Add(cl2)
|
||||
s.Topics.Subscribe("qos0", cl2.ID, 0)
|
||||
@@ -922,7 +1132,7 @@ func TestServerPublishInlineSysTopicError(t *testing.T) {
|
||||
require.Equal(t, int64(0), s.System.BytesSent)
|
||||
}
|
||||
|
||||
func TestServerProcessPublishHookOnMessage(t *testing.T) {
|
||||
func TestServerEventOnMessage(t *testing.T) {
|
||||
s, cl1, r1, w1 := setupClient()
|
||||
s.Clients.Add(cl1)
|
||||
s.Topics.Subscribe("a/b/+", cl1.ID, 0)
|
||||
@@ -1071,6 +1281,86 @@ func TestServerProcessPublishHookOnMessageModifyError(t *testing.T) {
|
||||
require.Equal(t, int64(14), s.System.BytesSent)
|
||||
}
|
||||
|
||||
func TestServerProcessPublishHookOnMessageAllowClients(t *testing.T) {
|
||||
s, cl1, r1, w1 := setupClient()
|
||||
cl1.ID = "allowed"
|
||||
s.Clients.Add(cl1)
|
||||
s.Topics.Subscribe("a/b/c", cl1.ID, 0)
|
||||
|
||||
cl2, r2, w2 := setupServerClient(s)
|
||||
cl2.ID = "not_allowed"
|
||||
s.Clients.Add(cl2)
|
||||
s.Topics.Subscribe("a/b/c", cl2.ID, 0)
|
||||
s.Topics.Subscribe("d/e/f", cl2.ID, 0)
|
||||
|
||||
s.Events.OnMessage = func(cl events.Client, pk events.Packet) (events.Packet, error) {
|
||||
hookedPacket := pk
|
||||
if pk.TopicName == "a/b/c" {
|
||||
hookedPacket.AllowClients = []string{"allowed"}
|
||||
}
|
||||
return hookedPacket, nil
|
||||
}
|
||||
|
||||
ack1 := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(r1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ack1 <- buf
|
||||
}()
|
||||
|
||||
ack2 := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(r2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ack2 <- buf
|
||||
}()
|
||||
|
||||
pk1 := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello"),
|
||||
}
|
||||
err := s.processPacket(cl1, pk1)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk2 := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "d/e/f",
|
||||
Payload: []byte("a"),
|
||||
}
|
||||
err = s.processPacket(cl1, pk2)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
w1.Close()
|
||||
w2.Close()
|
||||
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Publish << 4), 12,
|
||||
0, 5,
|
||||
'a', '/', 'b', '/', 'c',
|
||||
'h', 'e', 'l', 'l', 'o',
|
||||
}, <-ack1)
|
||||
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Publish << 4), 8,
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'a',
|
||||
}, <-ack2)
|
||||
|
||||
require.Equal(t, int64(24), s.System.BytesSent)
|
||||
}
|
||||
|
||||
func TestServerProcessPuback(t *testing.T) {
|
||||
s, cl, _, _ := setupClient()
|
||||
cl.Inflight.Set(11, clients.InflightMessage{Packet: packets.Packet{PacketID: 11}, Sent: 0})
|
||||
@@ -1443,7 +1733,7 @@ func TestServerCloseClientLWT(t *testing.T) {
|
||||
}
|
||||
s.Clients.Add(cl1)
|
||||
|
||||
_, cl2, r2, w2 := setupClient()
|
||||
cl2, r2, w2 := setupServerClient(s)
|
||||
cl2.ID = "mochi2"
|
||||
s.Clients.Add(cl2)
|
||||
|
||||
@@ -1566,14 +1856,14 @@ func TestServerLoadSubscriptions(t *testing.T) {
|
||||
s.Clients.Add(cl)
|
||||
|
||||
subs := []persistence.Subscription{
|
||||
persistence.Subscription{
|
||||
{
|
||||
ID: "test:a/b/c",
|
||||
Client: "test",
|
||||
Filter: "a/b/c",
|
||||
QoS: 1,
|
||||
T: persistence.KSubscription,
|
||||
},
|
||||
persistence.Subscription{
|
||||
{
|
||||
ID: "test:d/e/f",
|
||||
Client: "test",
|
||||
Filter: "d/e/f",
|
||||
@@ -1623,7 +1913,7 @@ func TestServerLoadInflight(t *testing.T) {
|
||||
require.NotNil(t, s)
|
||||
|
||||
msgs := []persistence.Message{
|
||||
persistence.Message{
|
||||
{
|
||||
ID: "client1_if_0",
|
||||
T: persistence.KInflight,
|
||||
Client: "client1",
|
||||
@@ -1633,7 +1923,7 @@ func TestServerLoadInflight(t *testing.T) {
|
||||
Sent: 100,
|
||||
Resends: 0,
|
||||
},
|
||||
persistence.Message{
|
||||
{
|
||||
ID: "client1_if_100",
|
||||
T: persistence.KInflight,
|
||||
Client: "client1",
|
||||
@@ -1668,7 +1958,7 @@ func TestServerLoadRetained(t *testing.T) {
|
||||
require.NotNil(t, s)
|
||||
|
||||
msgs := []persistence.Message{
|
||||
persistence.Message{
|
||||
{
|
||||
ID: "client1_ret_200",
|
||||
T: persistence.KRetained,
|
||||
FixedHeader: persistence.FixedHeader{
|
||||
@@ -1680,7 +1970,7 @@ func TestServerLoadRetained(t *testing.T) {
|
||||
Sent: 100,
|
||||
Resends: 0,
|
||||
},
|
||||
persistence.Message{
|
||||
{
|
||||
ID: "client1_ret_300",
|
||||
T: persistence.KRetained,
|
||||
FixedHeader: persistence.FixedHeader{
|
||||
|
21
server/system/system_test.go
Normal file
21
server/system/system_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInfoAlignment(t *testing.T) {
|
||||
typ := reflect.TypeOf(Info{})
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
switch f.Type.Kind() {
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
require.Equalf(t, uintptr(0), f.Offset%8,
|
||||
"%s requires 64-bit alignment for atomic: offset %d",
|
||||
f.Name, f.Offset)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user