Compare commits

..

37 Commits

Author SHA1 Message Date
mochi
86e0a5827e Update version to 1.1.0 2022-01-26 20:49:53 +00:00
mochi
06c399b606 indicate ARM32 compatibility 2022-01-26 20:49:42 +00:00
JB
ed117f67a1 Merge pull request #22 from mochi-co/feature/32bit-compatibility
ARM32 Compatibility
2022-01-26 20:36:13 +00:00
JB
880a3299e1 Merge pull request #19 from rkennedy/bugfix/32-bit-atomic-alignment
Improve 32-bit compatibility
2022-01-26 08:02:57 +00:00
Rob Kennedy
1c408d05be Fix encodeLength for 32-bit platforms
When `int` is 32 bits, `MaxInt64` doesn't fit. It's apparent that
`encodeLength` expects to handle 64-bit inputs, so let's make that
explicit, which allows the test to run on all platforms.
2022-01-25 00:22:26 -06:00
Rob Kennedy
fce495f83e Avoid race condition when closing listeners
"Atomic load" followed by "atomic store" is not itself an atomic
operation. This commit replaces that sequence with CompareAndSwap
instead.
2022-01-25 00:22:26 -06:00
Rob Kennedy
471ca00a64 Make atomics work on 32-bit systems
On 32-bit systems, `atomic` requires its 64-bit arguments to have 64-bit
alignment, but the compiler doesn't help ensure that's the case. In this
commit, fields that don't need to hold large numbers have been converted
to 32-bit types, which are always aligned correctly on all platforms.
For fields that may hold large numeric values, padding has been added to
get the necessary alignment, and tests have been added to avoid
regressions.
2022-01-25 00:22:26 -06:00
mochi
a2c0749640 Update server version to 1.0.5 2022-01-24 18:46:34 +00:00
JB
37293aeecf Merge pull request #18 from mochi-co/feature/connect-disconnect-hooks
OnConnect and OnDisconnect Event Hooks
2022-01-24 18:44:39 +00:00
mochi
7a2d4db6a4 Update for OnConnect and OnDisconnect hooks 2022-01-24 18:42:09 +00:00
mochi
03d2a8bc82 Add tests for OnConnect, OnDisconnect 2022-01-24 18:29:18 +00:00
mochi
4b51e5c7d1 Add OnConnect and OnDisconnect hooks to example 2022-01-24 17:42:33 +00:00
mochi
d15ad682bf Call OnDisconnect Event if applicable 2022-01-24 17:42:19 +00:00
mochi
130ffcbb53 Add OnDisconnect Event Hook 2022-01-24 17:42:04 +00:00
mochi
33cf2f991b Add testbolt file to ignore list 2022-01-24 17:41:46 +00:00
mochi
a360ea6a6c Call OnConnect Event if applicable 2022-01-24 17:37:11 +00:00
mochi
ae3aa0d3fa Add OnConnect event hook 2022-01-24 17:36:50 +00:00
mochi
811ae0e1be Prevent locks being copied by passing non-pointer to FromClient 2022-01-24 17:36:14 +00:00
JB
51d6825430 Merge pull request #15 from ClarkQAQ/master
Fixed some bugs, wish the project better and better
2022-01-17 10:08:20 +00:00
clark
514288c53e update tcp.go maybe this will be better 2022-01-16 20:06:49 +08:00
clark
957fc0a049 fix local variable black hole 2022-01-16 18:23:45 +08:00
clark
03f94f948a update mock.go plase use range 2022-01-16 18:22:37 +08:00
clark
1bc752a2b8 fix [ST1005] strings should not be capitalized 2022-01-16 18:21:33 +08:00
clark
b9db59ba12 update websocket.go fix check origin 2022-01-16 18:20:06 +08:00
JB
c0ef58c363 Update README.md 2022-01-14 17:48:21 +00:00
JB
994adea3b4 Merge pull request #14 from mochi-co/feature/allow-clients-value
Add AllowClients Field to packets
2022-01-14 17:38:29 +00:00
mochi
fc61cc9be5 Add example for AllowClients field 2022-01-14 17:04:55 +00:00
mochi
22d7338878 Add test for AllowClients field 2022-01-14 17:04:39 +00:00
mochi
3f28515706 Remove unnecessary type declarations 2022-01-14 17:04:21 +00:00
mochi
7d73ce9caf Add setupServerClients to inherit existing server instance
previously new clients generated a new server object, so system stats were not shared. This change ensures all test clients use the same server
2022-01-14 17:04:01 +00:00
mochi
0758bc961c Add AllowClients check in publishToSubscribers
If AllowClients has been set on a packet, ensure only clients in the slice are sent the message
2022-01-14 17:02:31 +00:00
mochi
8472b9ae8a use .systemInfo instead of .system for clarity 2022-01-14 17:01:42 +00:00
mochi
530a018e80 use .systemInfo instead of .system for clarity 2022-01-14 17:01:31 +00:00
mochi
0b594afb4e Add AllowClients field to packets
AllowClients field can be specified during onMessage event to selectively deliver messages
2022-01-14 16:59:17 +00:00
mochi
9d0ea957bb Increment server version 2022-01-14 16:58:48 +00:00
mochi
8067785ac4 Add tests for InSliceString 2022-01-14 16:58:33 +00:00
mochi
6ffc8a8388 Add InSliceString function
Check if a slice of strings contains a string (until slices package available)
2022-01-14 16:58:21 +00:00
24 changed files with 514 additions and 87 deletions

3
.gitignore vendored
View File

@@ -1,2 +1,3 @@
cmd/mqtt
.DS_Store
.DS_Store
server/persistence/bolt/testbolt.db

View File

@@ -25,9 +25,11 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
- Interfaces for Client Authentication and Topic access control.
- Bolt-backed persistence and storage interfaces.
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
- Basic Event Hooks (currently `onMessage`)
- Basic Event Hooks (currently `OnMessage`, `OnConnect`, `OnDisconnect`).
- ARM32 Compatible.
#### Roadmap
- Please open an issue to request new features or event hooks.
- MQTT v5 compatibility?
#### Using the Broker
@@ -105,8 +107,30 @@ err := server.AddListener(tcp, &listeners.Config{
#### Event Hooks
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
##### OnConnect
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
```go
import "github.com/mochi-co/mqtt/server/events"
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
```
##### OnDisconnect
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
```go
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
}
```
##### OnMessage
`server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns.
`server.Events.OnMessage` is called when a Publish packet is received. The method receives the published message and information about the client who published it.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
@@ -124,7 +148,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
}
```
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
#### Direct Publishing
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.

View File

@@ -44,6 +44,16 @@ func main() {
}
}()
// Add OnConnect Event Hook
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
// Add OnDisconnect Event Hook
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
}
// Add OnMessage Event Hook
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
pkx = pk
@@ -54,6 +64,12 @@ func main() {
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
}
// Example of using AllowClients to selectively deliver/drop messages.
// Only a client with the id of `allowed-client` will received messages on the topic.
if pkx.TopicName == "a/b/restricted" {
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
}
return pkx, nil
}

View File

@@ -6,7 +6,9 @@ import (
)
type Events struct {
OnMessage // published message receieved.
OnMessage // published message receieved.
OnConnect // client connected.
OnDisconnect // client disconnected.
}
type Packet packets.Packet
@@ -17,7 +19,7 @@ type Client struct {
}
// FromClient returns an event client from a client.
func FromClient(cl clients.Client) Client {
func FromClient(cl *clients.Client) Client {
return Client{
ID: cl.ID,
Listener: cl.Listener,
@@ -34,3 +36,11 @@ func FromClient(cl clients.Client) Client {
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
type OnMessage func(Client, Packet) (Packet, error)
// OnConnect is called when a client successfully connects to the broker.
type OnConnect func(Client, Packet)
// OnDisconnect is called when a client disconnects to the broker. An error value
// is passed to the function if the client disconnected abnormally, otherwise it
// will be nil on a normal disconnect.
type OnDisconnect func(Client, error)

View File

@@ -24,12 +24,13 @@ type Buffer struct {
block int // the size of the R/W block.
buf []byte // the bytes buffer.
tmp []byte // a temporary buffer.
_ int32 // align the next fields to an 8-byte boundary for atomic access.
head int64 // the current position in the sequence - a forever increasing index.
tail int64 // the committed position in the sequence - a forever increasing index.
rcond *sync.Cond // the sync condition for the buffer reader.
wcond *sync.Cond // the sync condition for the buffer writer.
done int64 // indicates that the buffer is closed.
State int64 // indicates whether the buffer is reading from (1) or writing to (2).
done uint32 // indicates that the buffer is closed.
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
}
// NewBuffer returns a new instance of buffer. You should call NewReader or
@@ -127,7 +128,7 @@ func (b *Buffer) awaitEmpty(n int) error {
// then wait until tail has moved.
b.rcond.L.Lock()
for !b.checkEmpty(n) {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
b.rcond.L.Unlock()
return io.EOF
}
@@ -146,7 +147,7 @@ func (b *Buffer) awaitFilled(n int) error {
// the forever-incrementing tail and head integers.
b.wcond.L.Lock()
for !b.checkFilled(n) {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
b.wcond.L.Unlock()
return io.EOF
}
@@ -196,7 +197,7 @@ func (b *Buffer) CapDelta() int {
// Stop signals the buffer to stop processing.
func (b *Buffer) Stop() {
atomic.StoreInt64(&b.done, 1)
atomic.StoreUint32(&b.done, 1)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()

View File

@@ -5,6 +5,7 @@ import (
"sync/atomic"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/require"
)
@@ -50,6 +51,18 @@ func TestNewBufferFromSlice0Size(t *testing.T) {
require.Equal(t, 256, cap(buf.buf))
}
func TestAtomicAlignment(t *testing.T) {
var b Buffer
offset := unsafe.Offsetof(b.head)
require.Equalf(t, uintptr(0), offset%8,
"head requires 64-bit alignment for atomic: offset %d", offset)
offset = unsafe.Offsetof(b.tail)
require.Equalf(t, uintptr(0), offset%8,
"tail requires 64-bit alignment for atomic: offset %d", offset)
}
func TestGetPos(t *testing.T) {
buf := NewBuffer(16, 4)
tail, head := buf.GetPos()
@@ -140,7 +153,7 @@ func TestAwaitFilledEnded(t *testing.T) {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
@@ -190,7 +203,7 @@ func TestAwaitEmptyEnded(t *testing.T) {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
@@ -280,7 +293,7 @@ func TestCommitTailEnded(t *testing.T) {
o <- buf.CommitTail(5)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
@@ -300,5 +313,5 @@ func TestCapDelta(t *testing.T) {
func TestStop(t *testing.T) {
buf := NewBuffer(16, 4)
buf.Stop()
require.Equal(t, int64(1), buf.done)
require.Equal(t, uint32(1), buf.done)
}

View File

@@ -32,10 +32,10 @@ func NewReaderFromSlice(block int, p []byte) *Reader {
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
atomic.StoreInt64(&b.State, 1)
defer atomic.StoreInt64(&b.State, 0)
atomic.StoreUint32(&b.State, 1)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
return total, nil
}

View File

@@ -60,7 +60,7 @@ func TestReadFromWrap(t *testing.T) {
}()
time.Sleep(time.Millisecond * 100)
go func() {
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
@@ -116,7 +116,7 @@ func TestReadEnded(t *testing.T) {
o <- err
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()

View File

@@ -32,10 +32,10 @@ func NewWriterFromSlice(block int, p []byte) *Writer {
// WriteTo writes the contents of the buffer to an io.Writer.
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
atomic.StoreInt64(&b.State, 2)
defer atomic.StoreInt64(&b.State, 0)
atomic.StoreUint32(&b.State, 2)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 {
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
return total, io.EOF
}

View File

@@ -59,7 +59,7 @@ func TestWriteTo(t *testing.T) {
}()
time.Sleep(time.Millisecond * 100)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()

View File

@@ -74,7 +74,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
clients := make([]*Client, 0, cl.Len())
cl.RLock()
for _, v := range cl.internal {
if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 {
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
clients = append(clients, v)
}
}
@@ -99,12 +99,12 @@ type Client struct {
packetID uint32 // the current highest packetID.
LWT LWT // the last will and testament for the client.
State State // the operational state of the client.
system *system.Info // pointers to server system info.
systemInfo *system.Info // pointers to server system info.
}
// State tracks the state of the client.
type State struct {
Done int64 // atomic counter which indicates that the client has closed.
Done uint32 // atomic counter which indicates that the client has closed.
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
@@ -114,11 +114,11 @@ type State struct {
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{
conn: c,
r: r,
w: w,
system: s,
keepalive: defaultKeepalive,
conn: c,
r: r,
w: w,
systemInfo: s,
keepalive: defaultKeepalive,
Inflight: Inflight{
internal: make(map[uint16]InflightMessage),
},
@@ -240,7 +240,7 @@ func (cl *Client) Start() {
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop() {
if atomic.LoadInt64(&cl.State.Done) == 1 {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return
}
@@ -252,7 +252,7 @@ func (cl *Client) Stop() {
cl.conn.Close()
cl.State.endedR.Wait()
atomic.StoreInt64(&cl.State.Done, 1)
atomic.StoreUint32(&cl.State.Done, 1)
})
}
@@ -300,7 +300,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
// Having successfully read n bytes, commit the tail forward.
cl.r.CommitTail(n)
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
return nil
}
@@ -309,7 +309,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
// an error is encountered (or the connection is closed).
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
for {
if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
return nil
}
@@ -334,7 +334,7 @@ func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.system.MessagesRecv, 1)
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
pk.FixedHeader = *fh
if pk.FixedHeader.Remaining == 0 {
@@ -345,7 +345,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
@@ -359,7 +359,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.system.PublishRecv, 1)
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
@@ -391,7 +391,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if atomic.LoadInt64(&cl.State.Done) == 1 {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return 0, ErrConnectionClosed
}
@@ -407,7 +407,7 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
case packets.Publish:
err = pk.PublishEncode(buf)
if err == nil {
atomic.AddInt64(&cl.system.PublishSent, 1)
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
}
case packets.Puback:
err = pk.PubackEncode(buf)
@@ -443,8 +443,9 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if err != nil {
return
}
atomic.AddInt64(&cl.system.BytesSent, int64(n))
atomic.AddInt64(&cl.system.MessagesSent, 1)
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
cl.refreshDeadline(cl.keepalive)

View File

@@ -316,8 +316,8 @@ func TestClientStart(t *testing.T) {
cl.Start()
defer cl.Stop()
time.Sleep(time.Millisecond)
require.Equal(t, int64(1), atomic.LoadInt64(&cl.r.State))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.w.State))
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.r.State))
require.Equal(t, uint32(2), atomic.LoadUint32(&cl.w.State))
}
func BenchmarkClientStart(b *testing.B) {
@@ -340,7 +340,7 @@ func TestClientReadFixedHeader(t *testing.T) {
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
tail, head := cl.r.GetPos()
require.Equal(t, int64(2), tail)
@@ -456,8 +456,8 @@ func TestClientReadOK(t *testing.T) {
},
})
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.MessagesRecv))
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.MessagesRecv))
}
@@ -574,7 +574,7 @@ func TestClientReadPacket(t *testing.T) {
require.Equal(t, tt.packet, pk, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishRecv))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishRecv))
}
}
}
@@ -647,10 +647,10 @@ func TestClientWritePacket(t *testing.T) {
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
cl.Stop()
require.Equal(t, int64(n), atomic.LoadInt64(&cl.system.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.MessagesSent))
require.Equal(t, int64(n), atomic.LoadInt64(&cl.systemInfo.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.MessagesSent))
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishSent))
}
}
}

View File

@@ -16,7 +16,7 @@ type FixedHeader struct {
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, fh.Remaining)
encodeLength(buf, int64(fh.Remaining))
}
// decode extracts the specification bits from the header byte.
@@ -44,7 +44,7 @@ func (fh *FixedHeader) Decode(headerByte byte) error {
}
// encodeLength writes length bits for the header.
func encodeLength(buf *bytes.Buffer, length int) {
func encodeLength(buf *bytes.Buffer, length int64) {
for {
digit := byte(length % 128)
length /= 128

View File

@@ -192,7 +192,7 @@ func BenchmarkFixedHeaderDecode(b *testing.B) {
func TestEncodeLength(t *testing.T) {
tt := []struct {
have int
have int64
want []byte
}{
{

View File

@@ -109,6 +109,10 @@ type Packet struct {
Topics []string
Qoss []byte
// If AllowClients set, only deliver to clients in the client allow list.
// For use with the OnMessage event hook.
AllowClients []string
ReturnCodes []byte // Suback
}

View File

@@ -0,0 +1,14 @@
package utils
// InSliceString returns true if a string exists in a slice of strings.
// This temporary and should be replaced with a function from the new
// go slices package in 1.19 when available.
// https://github.com/golang/go/issues/45955
func InSliceString(sl []string, st string) bool {
for _, v := range sl {
if st == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,18 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInSliceString(t *testing.T) {
sl := []string{"a", "b", "c"}
require.Equal(t, true, InSliceString(sl, "b"))
sl = []string{"a", "a", "a"}
require.Equal(t, true, InSliceString(sl, "a"))
sl = []string{"a", "b", "c"}
require.Equal(t, false, InSliceString(sl, "d"))
}

View File

@@ -22,7 +22,7 @@ type HTTPStats struct {
system *system.Info // pointers to the server data.
address string // the network address to bind to.
listen *http.Server // the http server.
end int64 // ensure the close methods are only called once.}
end uint32 // ensure the close methods are only called once.}
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
@@ -98,9 +98,7 @@ func (l *HTTPStats) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)

View File

@@ -44,11 +44,8 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
l.Lock()
l.Serving = true
l.Unlock()
for {
select {
case <-l.done:
return
}
for range l.done {
return
}
}

View File

@@ -18,7 +18,7 @@ type TCP struct {
protocol string // the TCP protocol to use.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
end int64 // ensure the close methods are only called once.
end uint32 // ensure the close methods are only called once.
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
@@ -63,10 +63,12 @@ func (l *TCP) Listen(s *system.Info) error {
var err error
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
@@ -84,7 +86,7 @@ func (l *TCP) Listen(s *system.Info) error {
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFunc) {
for {
if atomic.LoadInt64(&l.end) == 1 {
if atomic.LoadUint32(&l.end) == 1 {
return
}
@@ -93,7 +95,7 @@ func (l *TCP) Serve(establish EstablishFunc) {
return
}
if atomic.LoadInt64(&l.end) == 0 {
if atomic.LoadUint32(&l.end) == 0 {
go establish(l.id, conn, l.config.Auth)
}
}
@@ -104,8 +106,7 @@ func (l *TCP) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}

View File

@@ -17,12 +17,13 @@ import (
)
var (
ErrInvalidMessage = errors.New("Message type not binary")
ErrInvalidMessage = errors.New("message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool { return true },
}
)
@@ -33,7 +34,7 @@ type Websocket struct {
config *Config // configuration values for the listener.
address string // the network address to bind to.
listen *http.Server // an http server for serving websocket connections.
end int64 // ensure the close methods are only called once.
end uint32 // ensure the close methods are only called once.
establish EstablishFunc // the server's establish conection handler.
}
@@ -161,8 +162,7 @@ func (l *Websocket) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)

View File

@@ -14,6 +14,7 @@ import (
"github.com/mochi-co/mqtt/server/internal/clients"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
"github.com/mochi-co/mqtt/server/internal/utils"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/persistence"
@@ -21,7 +22,7 @@ import (
)
const (
Version = "1.0.2" // the server version.
Version = "1.1.0" // the server version.
)
var (
@@ -207,7 +208,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
var sessionPresent bool
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
existing.Lock()
if atomic.LoadInt64(&existing.State.Done) == 1 {
if atomic.LoadUint32(&existing.State.Done) == 1 {
atomic.AddInt64(&s.System.ClientsDisconnected, -1)
}
existing.Stop()
@@ -258,6 +259,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
})
}
if s.Events.OnConnect != nil {
s.Events.OnConnect(events.FromClient(cl), events.Packet(pk))
}
err = cl.Read(s.processPacket)
if err != nil {
s.closeClient(cl, true)
@@ -269,6 +274,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
atomic.AddInt64(&s.System.ClientsConnected, -1)
atomic.AddInt64(&s.System.ClientsDisconnected, 1)
if s.Events.OnDisconnect != nil {
s.Events.OnDisconnect(events.FromClient(cl), err)
}
return err
}
@@ -412,7 +421,7 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
// if an OnMessage hook exists, potentially modify the packet.
if s.Events.OnMessage != nil {
if pkx, err := s.Events.OnMessage(events.FromClient(*cl), events.Packet(pk)); err == nil {
if pkx, err := s.Events.OnMessage(events.FromClient(cl), events.Packet(pk)); err == nil {
pk = packets.Packet(pkx)
}
}
@@ -449,6 +458,14 @@ func (s *Server) retainMessage(pk packets.Packet) {
func (s *Server) publishToSubscribers(pk packets.Packet) {
for id, qos := range s.Topics.Subscribers(pk.TopicName) {
if client, ok := s.Clients.Get(id); ok {
// If the AllowClients value is set, only deliver the packet if the subscribed
// client exists in the AllowClients value. For use with the OnMessage event hook
// in cases where you want to publish messages to clients selectively.
if pk.AllowClients != nil && !utils.InSliceString(pk.AllowClients, id) {
continue
}
out := pk.PublishCopy()
if qos > out.FixedHeader.Qos { // Inherit higher desired qos values.
out.FixedHeader.Qos = qos

View File

@@ -36,6 +36,15 @@ func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) {
return
}
func setupServerClient(s *Server) (cl *clients.Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8), s.System)
cl.ID = "mochi"
cl.AC = new(auth.Allow)
cl.Start()
return
}
func TestNew(t *testing.T) {
s := New()
require.NotNil(t, s)
@@ -220,6 +229,203 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
w.Close()
}
func TestServerEventOnConnect(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedPacket events.Packet
s.Events.OnConnect = func(cl events.Client, pk events.Packet) {
hookedClient = cl
hookedPacket = pk
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
require.Empty(t, clw.Subscriptions)
w.Close()
time.Sleep(10 * time.Millisecond)
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
require.Equal(t, events.Packet(packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Connect,
Remaining: 17,
},
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
ProtocolVersion: 4,
CleanSession: true,
Keepalive: 45,
ClientIdentifier: "mochi",
}), hookedPacket)
}
func TestServerEventOnDisconnect(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedErr error
s.Events.OnDisconnect = func(cl events.Client, err error) {
hookedClient = cl
hookedErr = err
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
require.Empty(t, clw.Subscriptions)
w.Close()
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
require.Equal(t, nil, hookedErr)
}
func TestServerEventOnDisconnectOnError(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedErr error
s.Events.OnDisconnect = func(cl events.Client, err error) {
hookedClient = cl
hookedErr = err
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{0, 0})
}()
// Receive the Connack
go func() {
_, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.Error(t, errx)
require.Equal(t, "No valid packet available; 0", errx.Error())
fmt.Println(hookedErr)
require.Equal(t, errx, hookedErr)
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
}
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
s := New()
@@ -433,6 +639,10 @@ func TestServerEstablishConnectionReadPacketErr(t *testing.T) {
require.Error(t, errx)
}
func TestServerOnDisconnectErr(t *testing.T) {
}
func TestServerWriteClient(t *testing.T) {
s, cl, r, w := setupClient()
cl.ID = "mochi"
@@ -558,7 +768,7 @@ func TestServerProcessPublishQoS1Retain(t *testing.T) {
cl1.ID = "mochi1"
s.Clients.Add(cl1)
_, cl2, r2, w2 := setupClient()
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
@@ -687,7 +897,7 @@ func TestServerProcessPublishOfflineQueuing(t *testing.T) {
s.Clients.Add(cl1)
// Start and stop the receiver client
_, cl2, _, _ := setupClient()
cl2, _, _ := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
s.Topics.Subscribe("qos0", cl2.ID, 0)
@@ -922,7 +1132,7 @@ func TestServerPublishInlineSysTopicError(t *testing.T) {
require.Equal(t, int64(0), s.System.BytesSent)
}
func TestServerProcessPublishHookOnMessage(t *testing.T) {
func TestServerEventOnMessage(t *testing.T) {
s, cl1, r1, w1 := setupClient()
s.Clients.Add(cl1)
s.Topics.Subscribe("a/b/+", cl1.ID, 0)
@@ -1071,6 +1281,86 @@ func TestServerProcessPublishHookOnMessageModifyError(t *testing.T) {
require.Equal(t, int64(14), s.System.BytesSent)
}
func TestServerProcessPublishHookOnMessageAllowClients(t *testing.T) {
s, cl1, r1, w1 := setupClient()
cl1.ID = "allowed"
s.Clients.Add(cl1)
s.Topics.Subscribe("a/b/c", cl1.ID, 0)
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "not_allowed"
s.Clients.Add(cl2)
s.Topics.Subscribe("a/b/c", cl2.ID, 0)
s.Topics.Subscribe("d/e/f", cl2.ID, 0)
s.Events.OnMessage = func(cl events.Client, pk events.Packet) (events.Packet, error) {
hookedPacket := pk
if pk.TopicName == "a/b/c" {
hookedPacket.AllowClients = []string{"allowed"}
}
return hookedPacket, nil
}
ack1 := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r1)
if err != nil {
panic(err)
}
ack1 <- buf
}()
ack2 := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r2)
if err != nil {
panic(err)
}
ack2 <- buf
}()
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
}
err := s.processPacket(cl1, pk1)
require.NoError(t, err)
pk2 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "d/e/f",
Payload: []byte("a"),
}
err = s.processPacket(cl1, pk2)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
w1.Close()
w2.Close()
require.Equal(t, []byte{
byte(packets.Publish << 4), 12,
0, 5,
'a', '/', 'b', '/', 'c',
'h', 'e', 'l', 'l', 'o',
}, <-ack1)
require.Equal(t, []byte{
byte(packets.Publish << 4), 8,
0, 5,
'd', '/', 'e', '/', 'f',
'a',
}, <-ack2)
require.Equal(t, int64(24), s.System.BytesSent)
}
func TestServerProcessPuback(t *testing.T) {
s, cl, _, _ := setupClient()
cl.Inflight.Set(11, clients.InflightMessage{Packet: packets.Packet{PacketID: 11}, Sent: 0})
@@ -1443,7 +1733,7 @@ func TestServerCloseClientLWT(t *testing.T) {
}
s.Clients.Add(cl1)
_, cl2, r2, w2 := setupClient()
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
@@ -1566,14 +1856,14 @@ func TestServerLoadSubscriptions(t *testing.T) {
s.Clients.Add(cl)
subs := []persistence.Subscription{
persistence.Subscription{
{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
QoS: 1,
T: persistence.KSubscription,
},
persistence.Subscription{
{
ID: "test:d/e/f",
Client: "test",
Filter: "d/e/f",
@@ -1623,7 +1913,7 @@ func TestServerLoadInflight(t *testing.T) {
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
{
ID: "client1_if_0",
T: persistence.KInflight,
Client: "client1",
@@ -1633,7 +1923,7 @@ func TestServerLoadInflight(t *testing.T) {
Sent: 100,
Resends: 0,
},
persistence.Message{
{
ID: "client1_if_100",
T: persistence.KInflight,
Client: "client1",
@@ -1668,7 +1958,7 @@ func TestServerLoadRetained(t *testing.T) {
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
{
ID: "client1_ret_200",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
@@ -1680,7 +1970,7 @@ func TestServerLoadRetained(t *testing.T) {
Sent: 100,
Resends: 0,
},
persistence.Message{
{
ID: "client1_ret_300",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{

View File

@@ -0,0 +1,21 @@
package system
import (
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
func TestInfoAlignment(t *testing.T) {
typ := reflect.TypeOf(Info{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
switch f.Type.Kind() {
case reflect.Int64, reflect.Uint64:
require.Equalf(t, uintptr(0), f.Offset%8,
"%s requires 64-bit alignment for atomic: offset %d",
f.Name, f.Offset)
}
}
}