Compare commits

..

48 Commits

Author SHA1 Message Date
mochi
7f76445cc8 update server version 2022-01-30 10:39:49 +00:00
JB
b1c01792cd Merge pull request #24 from mochi-co/feature/optimise-struct-fields
Optimise Struct Fields + Fixes
2022-01-30 10:38:28 +00:00
mochi
eda03d4338 optimise Server struct 2022-01-30 10:30:34 +00:00
mochi
18070f1f57 pass byte pool by address 2022-01-30 10:30:19 +00:00
mochi
7f10c28a37 remove println 2022-01-30 10:30:01 +00:00
mochi
122531bb27 Pass inflight by address to avoid lock copying 2022-01-28 21:07:10 +00:00
mochi
e6dbcae428 Correct function signature 2022-01-28 21:06:57 +00:00
mochi
98875de568 Update test to match new FixedHeader struct 2022-01-28 21:06:43 +00:00
mochi
c9fd9451af Prevent locks from being copied 2022-01-28 21:06:24 +00:00
mochi
6550b8d680 8bit align struct fields 2022-01-28 21:05:50 +00:00
mochi
a60c96c889 Update comment for clarity 2022-01-28 21:04:15 +00:00
mochi
86e0a5827e Update version to 1.1.0 2022-01-26 20:49:53 +00:00
mochi
06c399b606 indicate ARM32 compatibility 2022-01-26 20:49:42 +00:00
JB
ed117f67a1 Merge pull request #22 from mochi-co/feature/32bit-compatibility
ARM32 Compatibility
2022-01-26 20:36:13 +00:00
JB
880a3299e1 Merge pull request #19 from rkennedy/bugfix/32-bit-atomic-alignment
Improve 32-bit compatibility
2022-01-26 08:02:57 +00:00
Rob Kennedy
1c408d05be Fix encodeLength for 32-bit platforms
When `int` is 32 bits, `MaxInt64` doesn't fit. It's apparent that
`encodeLength` expects to handle 64-bit inputs, so let's make that
explicit, which allows the test to run on all platforms.
2022-01-25 00:22:26 -06:00
Rob Kennedy
fce495f83e Avoid race condition when closing listeners
"Atomic load" followed by "atomic store" is not itself an atomic
operation. This commit replaces that sequence with CompareAndSwap
instead.
2022-01-25 00:22:26 -06:00
Rob Kennedy
471ca00a64 Make atomics work on 32-bit systems
On 32-bit systems, `atomic` requires its 64-bit arguments to have 64-bit
alignment, but the compiler doesn't help ensure that's the case. In this
commit, fields that don't need to hold large numbers have been converted
to 32-bit types, which are always aligned correctly on all platforms.
For fields that may hold large numeric values, padding has been added to
get the necessary alignment, and tests have been added to avoid
regressions.
2022-01-25 00:22:26 -06:00
mochi
a2c0749640 Update server version to 1.0.5 2022-01-24 18:46:34 +00:00
JB
37293aeecf Merge pull request #18 from mochi-co/feature/connect-disconnect-hooks
OnConnect and OnDisconnect Event Hooks
2022-01-24 18:44:39 +00:00
mochi
7a2d4db6a4 Update for OnConnect and OnDisconnect hooks 2022-01-24 18:42:09 +00:00
mochi
03d2a8bc82 Add tests for OnConnect, OnDisconnect 2022-01-24 18:29:18 +00:00
mochi
4b51e5c7d1 Add OnConnect and OnDisconnect hooks to example 2022-01-24 17:42:33 +00:00
mochi
d15ad682bf Call OnDisconnect Event if applicable 2022-01-24 17:42:19 +00:00
mochi
130ffcbb53 Add OnDisconnect Event Hook 2022-01-24 17:42:04 +00:00
mochi
33cf2f991b Add testbolt file to ignore list 2022-01-24 17:41:46 +00:00
mochi
a360ea6a6c Call OnConnect Event if applicable 2022-01-24 17:37:11 +00:00
mochi
ae3aa0d3fa Add OnConnect event hook 2022-01-24 17:36:50 +00:00
mochi
811ae0e1be Prevent locks being copied by passing non-pointer to FromClient 2022-01-24 17:36:14 +00:00
JB
51d6825430 Merge pull request #15 from ClarkQAQ/master
Fixed some bugs, wish the project better and better
2022-01-17 10:08:20 +00:00
clark
514288c53e update tcp.go maybe this will be better 2022-01-16 20:06:49 +08:00
clark
957fc0a049 fix local variable black hole 2022-01-16 18:23:45 +08:00
clark
03f94f948a update mock.go plase use range 2022-01-16 18:22:37 +08:00
clark
1bc752a2b8 fix [ST1005] strings should not be capitalized 2022-01-16 18:21:33 +08:00
clark
b9db59ba12 update websocket.go fix check origin 2022-01-16 18:20:06 +08:00
JB
c0ef58c363 Update README.md 2022-01-14 17:48:21 +00:00
JB
994adea3b4 Merge pull request #14 from mochi-co/feature/allow-clients-value
Add AllowClients Field to packets
2022-01-14 17:38:29 +00:00
mochi
fc61cc9be5 Add example for AllowClients field 2022-01-14 17:04:55 +00:00
mochi
22d7338878 Add test for AllowClients field 2022-01-14 17:04:39 +00:00
mochi
3f28515706 Remove unnecessary type declarations 2022-01-14 17:04:21 +00:00
mochi
7d73ce9caf Add setupServerClients to inherit existing server instance
previously new clients generated a new server object, so system stats were not shared. This change ensures all test clients use the same server
2022-01-14 17:04:01 +00:00
mochi
0758bc961c Add AllowClients check in publishToSubscribers
If AllowClients has been set on a packet, ensure only clients in the slice are sent the message
2022-01-14 17:02:31 +00:00
mochi
8472b9ae8a use .systemInfo instead of .system for clarity 2022-01-14 17:01:42 +00:00
mochi
530a018e80 use .systemInfo instead of .system for clarity 2022-01-14 17:01:31 +00:00
mochi
0b594afb4e Add AllowClients field to packets
AllowClients field can be specified during onMessage event to selectively deliver messages
2022-01-14 16:59:17 +00:00
mochi
9d0ea957bb Increment server version 2022-01-14 16:58:48 +00:00
mochi
8067785ac4 Add tests for InSliceString 2022-01-14 16:58:33 +00:00
mochi
6ffc8a8388 Add InSliceString function
Check if a slice of strings contains a string (until slices package available)
2022-01-14 16:58:21 +00:00
28 changed files with 622 additions and 212 deletions

3
.gitignore vendored
View File

@@ -1,2 +1,3 @@
cmd/mqtt
.DS_Store
.DS_Store
server/persistence/bolt/testbolt.db

View File

@@ -25,9 +25,11 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
- Interfaces for Client Authentication and Topic access control.
- Bolt-backed persistence and storage interfaces.
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
- Basic Event Hooks (currently `onMessage`)
- Basic Event Hooks (currently `OnMessage`, `OnConnect`, `OnDisconnect`).
- ARM32 Compatible.
#### Roadmap
- Please open an issue to request new features or event hooks.
- MQTT v5 compatibility?
#### Using the Broker
@@ -105,8 +107,30 @@ err := server.AddListener(tcp, &listeners.Config{
#### Event Hooks
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
##### OnConnect
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
```go
import "github.com/mochi-co/mqtt/server/events"
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
```
##### OnDisconnect
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
```go
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
}
```
##### OnMessage
`server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns.
`server.Events.OnMessage` is called when a Publish packet is received. The method receives the published message and information about the client who published it.
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
@@ -124,7 +148,8 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
}
```
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
#### Direct Publishing
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.

View File

@@ -44,6 +44,16 @@ func main() {
}
}()
// Add OnConnect Event Hook
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
}
// Add OnDisconnect Event Hook
server.Events.OnDisconnect = func(cl events.Client, err error) {
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
}
// Add OnMessage Event Hook
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
pkx = pk
@@ -54,6 +64,12 @@ func main() {
fmt.Printf("< OnMessage received message from client %s: %s\n", cl.ID, string(pkx.Payload))
}
// Example of using AllowClients to selectively deliver/drop messages.
// Only a client with the id of `allowed-client` will received messages on the topic.
if pkx.TopicName == "a/b/restricted" {
pkx.AllowClients = []string{"allowed-client"} // slice of known client ids
}
return pkx, nil
}

View File

@@ -6,7 +6,9 @@ import (
)
type Events struct {
OnMessage // published message receieved.
OnMessage // published message receieved.
OnConnect // client connected.
OnDisconnect // client disconnected.
}
type Packet packets.Packet
@@ -17,7 +19,7 @@ type Client struct {
}
// FromClient returns an event client from a client.
func FromClient(cl clients.Client) Client {
func FromClient(cl *clients.Client) Client {
return Client{
ID: cl.ID,
Listener: cl.Listener,
@@ -34,3 +36,11 @@ func FromClient(cl clients.Client) Client {
// This function will block message dispatching until it returns. To minimise this,
// have the function open a new goroutine on the embedding side.
type OnMessage func(Client, Packet) (Packet, error)
// OnConnect is called when a client successfully connects to the broker.
type OnConnect func(Client, Packet)
// OnDisconnect is called when a client disconnects to the broker. An error value
// is passed to the function if the client disconnected abnormally, otherwise it
// will be nil on a normal disconnect.
type OnDisconnect func(Client, error)

View File

@@ -15,26 +15,25 @@ var (
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
)
// buffer contains core values and methods to be included in a reader or writer.
type Buffer struct {
Mu sync.RWMutex // the buffer needs it's own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
buf []byte // the bytes buffer.
tmp []byte // a temporary buffer.
Mu sync.RWMutex // the buffer needs its own mutex to work properly.
ID string // the identifier of the buffer. This is used in debug output.
head int64 // the current position in the sequence - a forever increasing index.
tail int64 // the committed position in the sequence - a forever increasing index.
rcond *sync.Cond // the sync condition for the buffer reader.
wcond *sync.Cond // the sync condition for the buffer writer.
done int64 // indicates that the buffer is closed.
State int64 // indicates whether the buffer is reading from (1) or writing to (2).
size int // the size of the buffer.
mask int // a bitmask of the buffer size (size-1).
block int // the size of the R/W block.
done uint32 // indicates that the buffer is closed.
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
}
// NewBuffer returns a new instance of buffer. You should call NewReader or
// NewWriter instead of this function.
func NewBuffer(size, block int) Buffer {
func NewBuffer(size, block int) *Buffer {
if size == 0 {
size = DefaultBufferSize
}
@@ -47,7 +46,7 @@ func NewBuffer(size, block int) Buffer {
size = 2 * block
}
return Buffer{
return &Buffer{
size: size,
mask: size - 1,
block: block,
@@ -59,14 +58,14 @@ func NewBuffer(size, block int) Buffer {
// NewBufferFromSlice returns a new instance of buffer using a
// pre-existing byte slice.
func NewBufferFromSlice(block int, buf []byte) Buffer {
func NewBufferFromSlice(block int, buf []byte) *Buffer {
l := len(buf)
if block == 0 {
block = DefaultBlockSize
}
b := Buffer{
b := &Buffer{
size: l,
mask: l - 1,
block: block,
@@ -127,7 +126,7 @@ func (b *Buffer) awaitEmpty(n int) error {
// then wait until tail has moved.
b.rcond.L.Lock()
for !b.checkEmpty(n) {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
b.rcond.L.Unlock()
return io.EOF
}
@@ -146,7 +145,7 @@ func (b *Buffer) awaitFilled(n int) error {
// the forever-incrementing tail and head integers.
b.wcond.L.Lock()
for !b.checkFilled(n) {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
b.wcond.L.Unlock()
return io.EOF
}
@@ -196,7 +195,7 @@ func (b *Buffer) CapDelta() int {
// Stop signals the buffer to stop processing.
func (b *Buffer) Stop() {
atomic.StoreInt64(&b.done, 1)
atomic.StoreUint32(&b.done, 1)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()

View File

@@ -5,6 +5,7 @@ import (
"sync/atomic"
"testing"
"time"
"unsafe"
"github.com/stretchr/testify/require"
)
@@ -50,6 +51,18 @@ func TestNewBufferFromSlice0Size(t *testing.T) {
require.Equal(t, 256, cap(buf.buf))
}
func TestAtomicAlignment(t *testing.T) {
var b Buffer
offset := unsafe.Offsetof(b.head)
require.Equalf(t, uintptr(0), offset%8,
"head requires 64-bit alignment for atomic: offset %d", offset)
offset = unsafe.Offsetof(b.tail)
require.Equalf(t, uintptr(0), offset%8,
"tail requires 64-bit alignment for atomic: offset %d", offset)
}
func TestGetPos(t *testing.T) {
buf := NewBuffer(16, 4)
tail, head := buf.GetPos()
@@ -140,7 +153,7 @@ func TestAwaitFilledEnded(t *testing.T) {
o <- buf.awaitFilled(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
@@ -190,7 +203,7 @@ func TestAwaitEmptyEnded(t *testing.T) {
o <- buf.awaitEmpty(4)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
@@ -280,7 +293,7 @@ func TestCommitTailEnded(t *testing.T) {
o <- buf.CommitTail(5)
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
@@ -300,5 +313,5 @@ func TestCapDelta(t *testing.T) {
func TestStop(t *testing.T) {
buf := NewBuffer(16, 4)
buf.Stop()
require.Equal(t, int64(1), buf.done)
require.Equal(t, uint32(1), buf.done)
}

View File

@@ -6,13 +6,13 @@ import (
// BytesPool is a pool of []byte.
type BytesPool struct {
pool sync.Pool
pool *sync.Pool
}
// NewBytesPool returns a sync.pool of []byte.
func NewBytesPool(n int) BytesPool {
return BytesPool{
pool: sync.Pool{
func NewBytesPool(n int) *BytesPool {
return &BytesPool{
pool: &sync.Pool{
New: func() interface{} {
return make([]byte, n)
},

View File

@@ -7,7 +7,7 @@ import (
// Reader is a circular buffer for reading data from an io.Reader.
type Reader struct {
Buffer
*Buffer
}
// NewReader returns a new Circular Reader.
@@ -32,10 +32,10 @@ func NewReaderFromSlice(block int, p []byte) *Reader {
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
atomic.StoreInt64(&b.State, 1)
defer atomic.StoreInt64(&b.State, 0)
atomic.StoreUint32(&b.State, 1)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 {
if atomic.LoadUint32(&b.done) == 1 {
return total, nil
}

View File

@@ -60,7 +60,7 @@ func TestReadFromWrap(t *testing.T) {
}()
time.Sleep(time.Millisecond * 100)
go func() {
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
@@ -116,7 +116,7 @@ func TestReadEnded(t *testing.T) {
o <- err
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()

View File

@@ -1,14 +1,14 @@
package circ
import (
"fmt"
"io"
"log"
"sync/atomic"
)
// Writer is a circular buffer for writing data to an io.Writer.
type Writer struct {
Buffer
*Buffer
}
// NewWriter returns a pointer to a new Circular Writer.
@@ -31,11 +31,11 @@ func NewWriterFromSlice(block int, p []byte) *Writer {
}
// WriteTo writes the contents of the buffer to an io.Writer.
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
atomic.StoreInt64(&b.State, 2)
defer atomic.StoreInt64(&b.State, 0)
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
atomic.StoreUint32(&b.State, 2)
defer atomic.StoreUint32(&b.State, 0)
for {
if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 {
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
return total, io.EOF
}
@@ -59,14 +59,12 @@ func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
p = append(p, b.buf[rTail:rHead]...)
}
//fmt.Println("writing", p)
n, err = w.Write(p)
total += n
total += int64(n)
if err != nil {
fmt.Println("writing err", err)
log.Println("error writing to buffer io.Writer;", err)
return
}
//fmt.Println("written", n)
// Move the tail forward the bytes written and broadcast change.
atomic.StoreInt64(&b.tail, tail+int64(n))

View File

@@ -52,14 +52,14 @@ func TestWriteTo(t *testing.T) {
var b bytes.Buffer
w := bufio.NewWriter(&b)
nc := make(chan int)
nc := make(chan int64)
go func() {
n, _ := buf.WriteTo(w)
nc <- n
}()
time.Sleep(time.Millisecond * 100)
atomic.StoreInt64(&buf.done, 1)
atomic.StoreUint32(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()

View File

@@ -74,7 +74,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
clients := make([]*Client, 0, cl.Len())
cl.RLock()
for _, v := range cl.internal {
if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 {
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
clients = append(clients, v)
}
}
@@ -84,42 +84,42 @@ func (cl *Clients) GetByListener(id string) []*Client {
// Client contains information about a client known by the broker.
type Client struct {
sync.RWMutex
State State // the operational state of the client.
LWT LWT // the last will and testament for the client.
Inflight *Inflight // a map of in-flight qos messages.
sync.RWMutex // mutex
Username []byte // the username the client authenticated with.
AC auth.Controller // an auth controller inherited from the listener.
Listener string // the id of the listener the client is connected to.
ID string // the client id.
conn net.Conn // the net.Conn used to establish the connection.
r *circ.Reader // a reader for reading incoming bytes.
w *circ.Writer // a writer for writing outgoing bytes.
ID string // the client id.
AC auth.Controller // an auth controller inherited from the listener.
Subscriptions topics.Subscriptions // a map of the subscription filters a client maintains.
Listener string // the id of the listener the client is connected to.
Inflight Inflight // a map of in-flight qos messages.
Username []byte // the username the client authenticated with.
systemInfo *system.Info // pointers to server system info.
packetID uint32 // the current highest packetID.
keepalive uint16 // the number of seconds the connection can wait.
cleanSession bool // indicates if the client expects a clean-session.
packetID uint32 // the current highest packetID.
LWT LWT // the last will and testament for the client.
State State // the operational state of the client.
system *system.Info // pointers to server system info.
}
// State tracks the state of the client.
type State struct {
Done int64 // atomic counter which indicates that the client has closed.
started *sync.WaitGroup // tracks the goroutines which have been started.
endedW *sync.WaitGroup // tracks when the writer has ended.
endedR *sync.WaitGroup // tracks when the reader has ended.
Done uint32 // atomic counter which indicates that the client has closed.
endOnce sync.Once // only end once.
}
// NewClient returns a new instance of Client.
func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Client {
cl := &Client{
conn: c,
r: r,
w: w,
system: s,
keepalive: defaultKeepalive,
Inflight: Inflight{
conn: c,
r: r,
w: w,
systemInfo: s,
keepalive: defaultKeepalive,
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
@@ -139,7 +139,7 @@ func NewClient(c net.Conn, r *circ.Reader, w *circ.Writer, s *system.Info) *Clie
// method is typically called by the persistence restoration system.
func NewClientStub(s *system.Info) *Client {
return &Client{
Inflight: Inflight{
Inflight: &Inflight{
internal: make(map[uint16]InflightMessage),
},
Subscriptions: make(map[string]byte),
@@ -240,7 +240,7 @@ func (cl *Client) Start() {
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop() {
if atomic.LoadInt64(&cl.State.Done) == 1 {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return
}
@@ -252,7 +252,7 @@ func (cl *Client) Stop() {
cl.conn.Close()
cl.State.endedR.Wait()
atomic.StoreInt64(&cl.State.Done, 1)
atomic.StoreUint32(&cl.State.Done, 1)
})
}
@@ -300,7 +300,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
// Having successfully read n bytes, commit the tail forward.
cl.r.CommitTail(n)
atomic.AddInt64(&cl.system.BytesRecv, int64(n))
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(n))
return nil
}
@@ -309,7 +309,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
// an error is encountered (or the connection is closed).
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
for {
if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
return nil
}
@@ -334,7 +334,7 @@ func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.system.MessagesRecv, 1)
atomic.AddInt64(&cl.systemInfo.MessagesRecv, 1)
pk.FixedHeader = *fh
if pk.FixedHeader.Remaining == 0 {
@@ -345,7 +345,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.system.BytesRecv, int64(len(p)))
atomic.AddInt64(&cl.systemInfo.BytesRecv, int64(len(p)))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
@@ -359,7 +359,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.system.PublishRecv, 1)
atomic.AddInt64(&cl.systemInfo.PublishRecv, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
@@ -391,7 +391,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if atomic.LoadInt64(&cl.State.Done) == 1 {
if atomic.LoadUint32(&cl.State.Done) == 1 {
return 0, ErrConnectionClosed
}
@@ -407,7 +407,7 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
case packets.Publish:
err = pk.PublishEncode(buf)
if err == nil {
atomic.AddInt64(&cl.system.PublishSent, 1)
atomic.AddInt64(&cl.systemInfo.PublishSent, 1)
}
case packets.Puback:
err = pk.PubackEncode(buf)
@@ -443,8 +443,9 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
if err != nil {
return
}
atomic.AddInt64(&cl.system.BytesSent, int64(n))
atomic.AddInt64(&cl.system.MessagesSent, 1)
atomic.AddInt64(&cl.systemInfo.BytesSent, int64(n))
atomic.AddInt64(&cl.systemInfo.MessagesSent, 1)
cl.refreshDeadline(cl.keepalive)
@@ -453,8 +454,8 @@ func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
// LWT contains the last will and testament details for a client connection.
type LWT struct {
Topic string // the topic the will message shall be sent to.
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}

View File

@@ -316,8 +316,8 @@ func TestClientStart(t *testing.T) {
cl.Start()
defer cl.Stop()
time.Sleep(time.Millisecond)
require.Equal(t, int64(1), atomic.LoadInt64(&cl.r.State))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.w.State))
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.r.State))
require.Equal(t, uint32(2), atomic.LoadUint32(&cl.w.State))
}
func BenchmarkClientStart(b *testing.B) {
@@ -340,7 +340,7 @@ func TestClientReadFixedHeader(t *testing.T) {
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
tail, head := cl.r.GetPos()
require.Equal(t, int64(2), tail)
@@ -456,8 +456,8 @@ func TestClientReadOK(t *testing.T) {
},
})
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.system.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.system.MessagesRecv))
require.Equal(t, int64(len(b)), atomic.LoadInt64(&cl.systemInfo.BytesRecv))
require.Equal(t, int64(2), atomic.LoadInt64(&cl.systemInfo.MessagesRecv))
}
@@ -574,7 +574,7 @@ func TestClientReadPacket(t *testing.T) {
require.Equal(t, tt.packet, pk, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishRecv))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishRecv))
}
}
}
@@ -647,10 +647,10 @@ func TestClientWritePacket(t *testing.T) {
require.Equal(t, tt.bytes, <-o, "Mismatched packet: [i:%d] %d", i, tt.bytes[0])
cl.Stop()
require.Equal(t, int64(n), atomic.LoadInt64(&cl.system.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.MessagesSent))
require.Equal(t, int64(n), atomic.LoadInt64(&cl.systemInfo.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.MessagesSent))
if tt.packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.system.PublishSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.systemInfo.PublishSent))
}
}
}

View File

@@ -6,17 +6,17 @@ import (
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Dup bool // indicates if the packet was already sent at an earlier time.
Qos byte // indicates the quality of service expected.
Retain bool // whether the message should be retained.
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, fh.Remaining)
encodeLength(buf, int64(fh.Remaining))
}
// decode extracts the specification bits from the header byte.
@@ -44,7 +44,7 @@ func (fh *FixedHeader) Decode(headerByte byte) error {
}
// encodeLength writes length bits for the header.
func encodeLength(buf *bytes.Buffer, length int) {
func encodeLength(buf *bytes.Buffer, length int64) {
for {
digit := byte(length % 128)
length /= 128

View File

@@ -18,130 +18,130 @@ type fixedHeaderTable struct {
var fixedHeaderExpected = []fixedHeaderTable{
{
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Connack, false, 0, false, 0},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Publish, false, 0, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Publish, false, 1, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 1, true, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Publish, false, 2, false, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 2, true, 0},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Publish, true, 0, false, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Publish, true, 0, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 1, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 2, true, 0},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Puback, false, 0, false, 0},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Pubrec, false, 0, false, 0},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Pubrel, false, 1, false, 0},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Pubcomp, false, 0, false, 0},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Subscribe, false, 1, false, 0},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Suback, false, 0, false, 0},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Unsubscribe, false, 1, false, 0},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Unsuback, false, 0, false, 0},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Pingreq, false, 0, false, 0},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Pingresp, false, 0, false, 0},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Disconnect, false, 0, false, 0},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Publish, false, 0, false, 10},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Publish, false, 0, false, 512},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Publish, false, 0, false, 978},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Publish, false, 0, false, 20102},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Publish, false, 0, false, 333333333},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Connect, true, 0, false, 0},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Connect, false, 1, false, 0},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Connect, false, 0, true, 0},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
flagError: true,
},
}
@@ -192,7 +192,7 @@ func BenchmarkFixedHeaderDecode(b *testing.B) {
func TestEncodeLength(t *testing.T) {
tt := []struct {
have int
have int64
want []byte
}{
{

View File

@@ -76,40 +76,31 @@ var (
// packet structs, this is a single concrete packet type to cover all packet
// types, which allows us to take advantage of various compiler optimizations.
type Packet struct {
FixedHeader FixedHeader
PacketID uint16
// Connect
FixedHeader FixedHeader
AllowClients []string // For use with OnMessage event hook.
Topics []string
ReturnCodes []byte
ProtocolName []byte
Qoss []byte
Payload []byte
Username []byte
Password []byte
WillMessage []byte
ClientIdentifier string
TopicName string
WillTopic string
PacketID uint16
Keepalive uint16
ReturnCode byte
ProtocolVersion byte
WillQos byte
ReservedBit byte
CleanSession bool
WillFlag bool
WillQos byte
WillRetain bool
UsernameFlag bool
PasswordFlag bool
ReservedBit byte
Keepalive uint16
ClientIdentifier string
WillTopic string
WillMessage []byte
Username []byte
Password []byte
// Connack
SessionPresent bool
ReturnCode byte
// Publish
TopicName string
Payload []byte
// Subscribe, Unsubscribe
Topics []string
Qoss []byte
ReturnCodes []byte // Suback
SessionPresent bool
}
// ConnectEncode encodes a connect packet.

View File

@@ -175,12 +175,12 @@ func (x *Index) Messages(filter string) []packets.Packet {
// Leaf is a child node on the tree.
type Leaf struct {
Message packets.Packet // a message which has been retained for a specific topic.
Key string // the key that was used to create the leaf.
Filter string // the path of the topic filter being matched.
Parent *Leaf // a pointer to the parent node for the leaf.
Leaves map[string]*Leaf // a map of child nodes, keyed on particle id.
Clients map[string]byte // a map of client ids subscribed to the topic.
Filter string // the path of the topic filter being matched.
Message packets.Packet // a message which has been retained for a specific topic.
}
// scanSubscribers recursively steps through a branch of leaves finding clients who

View File

@@ -0,0 +1,14 @@
package utils
// InSliceString returns true if a string exists in a slice of strings.
// This temporary and should be replaced with a function from the new
// go slices package in 1.19 when available.
// https://github.com/golang/go/issues/45955
func InSliceString(sl []string, st string) bool {
for _, v := range sl {
if st == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,18 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestInSliceString(t *testing.T) {
sl := []string{"a", "b", "c"}
require.Equal(t, true, InSliceString(sl, "b"))
sl = []string{"a", "a", "a"}
require.Equal(t, true, InSliceString(sl, "a"))
sl = []string{"a", "b", "c"}
require.Equal(t, false, InSliceString(sl, "d"))
}

View File

@@ -18,11 +18,11 @@ import (
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
system *system.Info // pointers to the server data.
address string // the network address to bind to.
listen *http.Server // the http server.
end int64 // ensure the close methods are only called once.}
end uint32 // ensure the close methods are only called once.}
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
@@ -98,9 +98,7 @@ func (l *HTTPStats) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)

View File

@@ -38,10 +38,10 @@ type Listener interface {
// Listeners contains the network listeners for the broker.
type Listeners struct {
sync.RWMutex
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
internal map[string]Listener // a map of active listeners.
system *system.Info // pointers to system info.
sync.RWMutex
}
// New returns a new instance of Listeners.

View File

@@ -22,11 +22,11 @@ func MockEstablisher(id string, c net.Conn, ac auth.Controller) error {
type MockListener struct {
sync.RWMutex
id string // the id of the listener.
Config *Config // configuration for the listener.
address string // the network address the listener binds to.
Listening bool // indiciate the listener is listening.
Serving bool // indicate the listener is serving.
Config *Config // configuration for the listener.
done chan bool // indicate the listener is done.
Serving bool // indicate the listener is serving.
Listening bool // indiciate the listener is listening.
ErrListen bool // throw an error on listen.
}
@@ -44,11 +44,8 @@ func (l *MockListener) Serve(establisher EstablishFunc) {
l.Lock()
l.Serving = true
l.Unlock()
for {
select {
case <-l.done:
return
}
for range l.done {
return
}
}

View File

@@ -14,11 +14,11 @@ import (
type TCP struct {
sync.RWMutex
id string // the internal id of the listener.
config *Config // configuration values for the listener.
protocol string // the TCP protocol to use.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
end int64 // ensure the close methods are only called once.
config *Config // configuration values for the listener.
end uint32 // ensure the close methods are only called once.
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
@@ -63,10 +63,12 @@ func (l *TCP) Listen(s *system.Info) error {
var err error
if l.config.TLS != nil && len(l.config.TLS.Certificate) > 0 && len(l.config.TLS.PrivateKey) > 0 {
cert, err := tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
var cert tls.Certificate
cert, err = tls.X509KeyPair(l.config.TLS.Certificate, l.config.TLS.PrivateKey)
if err != nil {
return err
}
l.listen, err = tls.Listen(l.protocol, l.address, &tls.Config{
Certificates: []tls.Certificate{cert},
})
@@ -84,7 +86,7 @@ func (l *TCP) Listen(s *system.Info) error {
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFunc) {
for {
if atomic.LoadInt64(&l.end) == 1 {
if atomic.LoadUint32(&l.end) == 1 {
return
}
@@ -93,7 +95,7 @@ func (l *TCP) Serve(establish EstablishFunc) {
return
}
if atomic.LoadInt64(&l.end) == 0 {
if atomic.LoadUint32(&l.end) == 0 {
go establish(l.id, conn, l.config.Auth)
}
}
@@ -104,8 +106,7 @@ func (l *TCP) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}

View File

@@ -17,12 +17,13 @@ import (
)
var (
ErrInvalidMessage = errors.New("Message type not binary")
ErrInvalidMessage = errors.New("message type not binary")
// wsUpgrader is used to upgrade the incoming http/tcp connection to a
// websocket compliant connection.
wsUpgrader = &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool { return true },
}
)
@@ -30,11 +31,11 @@ var (
type Websocket struct {
sync.RWMutex
id string // the internal id of the listener.
config *Config // configuration values for the listener.
address string // the network address to bind to.
config *Config // configuration values for the listener.
listen *http.Server // an http server for serving websocket connections.
end int64 // ensure the close methods are only called once.
establish EstablishFunc // the server's establish conection handler.
end uint32 // ensure the close methods are only called once.
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
@@ -161,8 +162,7 @@ func (l *Websocket) Close(closeClients CloseFunc) {
l.Lock()
defer l.Unlock()
if atomic.LoadInt64(&l.end) == 0 {
atomic.StoreInt64(&l.end, 1)
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)

View File

@@ -53,50 +53,50 @@ type Subscription struct {
// Message contains the details of a retained or inflight message.
type Message struct {
ID string // the storage key.
T string // the type of the stored data.
Client string // the id of the client who sent the message (if inflight).
FixedHeader FixedHeader // the header properties of the message.
PacketID uint16 // the unique id of the packet (if inflight).
TopicName string // the topic the message was sent to (if retained).
Payload []byte // the message payload (if retained).
FixedHeader FixedHeader // the header properties of the message.
T string // the type of the stored data.
ID string // the storage key.
Client string // the id of the client who sent the message (if inflight).
TopicName string // the topic the message was sent to (if retained).
Sent int64 // the last time the message was sent (for retries) in unixtime (if inflight).
Resends int // the number of times the message was attempted to be sent (if inflight).
PacketID uint16 // the unique id of the packet (if inflight).
}
// FixedHeader contains the fixed header properties of a message.
type FixedHeader struct {
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Dup bool // indicates if the packet was already sent at an earlier time.
Qos byte // indicates the quality of service expected.
Retain bool // whether the message should be retained.
Remaining int // the number of remaining bytes in the payload.
Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte // indicates the quality of service expected.
Dup bool // indicates if the packet was already sent at an earlier time.
Retain bool // whether the message should be retained.
}
// Client contains client data that can be persistently stored.
type Client struct {
LWT LWT // the last-will-and-testament message for the client.
Username []byte // the username the client authenticated with.
ID string // the storage key.
ClientID string // the id of the client.
T string // the type of the stored data.
Listener string // the last known listener id for the client
Username []byte // the username the client authenticated with.
LWT LWT // the last-will-and-testament message for the client.
}
// LWT contains details about a clients LWT payload.
type LWT struct {
Topic string // the topic the will message shall be sent to.
Message []byte // the message that shall be sent when the client disconnects.
Topic string // the topic the will message shall be sent to.
Qos byte // the quality of service desired.
Retain bool // indicates whether the will message should be retained
}
// MockStore is a mock storage backend for testing.
type MockStore struct {
Fail map[string]bool // issue errors for different methods.
FailOpen bool // error on open.
Closed bool // indicate mock store is closed.
Opened bool // indicate mock store is open.
Fail map[string]bool // issue errors for different methods.
}
// Open opens the storage instance.

View File

@@ -14,6 +14,7 @@ import (
"github.com/mochi-co/mqtt/server/internal/clients"
"github.com/mochi-co/mqtt/server/internal/packets"
"github.com/mochi-co/mqtt/server/internal/topics"
"github.com/mochi-co/mqtt/server/internal/utils"
"github.com/mochi-co/mqtt/server/listeners"
"github.com/mochi-co/mqtt/server/listeners/auth"
"github.com/mochi-co/mqtt/server/persistence"
@@ -21,7 +22,7 @@ import (
)
const (
Version = "1.0.2" // the server version.
Version = "1.1.1" // the server version.
)
var (
@@ -44,16 +45,16 @@ var (
// Server is an MQTT broker server. It should be created with server.New()
// in order to ensure all the internal fields are correctly populated.
type Server struct {
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
Store persistence.Store // a persistent storage backend if desired.
Listeners *listeners.Listeners // listeners are network interfaces which listen for new connections.
Clients *clients.Clients // clients which are known to the broker.
Topics *topics.Index // an index of topic filter subscriptions and retained messages.
System *system.Info // values about the server commonly found in $SYS topics.
Store persistence.Store // a persistent storage backend if desired.
done chan bool // indicate that the server is ending.
bytepool circ.BytesPool // a byte pool for incoming and outgoing packets.
bytepool *circ.BytesPool // a byte pool for incoming and outgoing packets.
sysTicker *time.Ticker // the interval ticker for sending updating $SYS topics.
inline inlineMessages // channels for direct publishing.
Events events.Events // overrideable event hooks.
done chan bool // indicate that the server is ending.
}
// inlineMessages contains channels for handling inline (direct) publishing.
@@ -207,7 +208,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
var sessionPresent bool
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
existing.Lock()
if atomic.LoadInt64(&existing.State.Done) == 1 {
if atomic.LoadUint32(&existing.State.Done) == 1 {
atomic.AddInt64(&s.System.ClientsDisconnected, -1)
}
existing.Stop()
@@ -220,7 +221,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
}
}
} else {
cl.Inflight = existing.Inflight // Inherit from existing session.
cl.Inflight = existing.Inflight // Take address of existing session.
cl.Subscriptions = existing.Subscriptions
sessionPresent = true
}
@@ -258,6 +259,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
})
}
if s.Events.OnConnect != nil {
s.Events.OnConnect(events.FromClient(cl), events.Packet(pk))
}
err = cl.Read(s.processPacket)
if err != nil {
s.closeClient(cl, true)
@@ -269,6 +274,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
atomic.AddInt64(&s.System.ClientsConnected, -1)
atomic.AddInt64(&s.System.ClientsDisconnected, 1)
if s.Events.OnDisconnect != nil {
s.Events.OnDisconnect(events.FromClient(cl), err)
}
return err
}
@@ -412,7 +421,7 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
// if an OnMessage hook exists, potentially modify the packet.
if s.Events.OnMessage != nil {
if pkx, err := s.Events.OnMessage(events.FromClient(*cl), events.Packet(pk)); err == nil {
if pkx, err := s.Events.OnMessage(events.FromClient(cl), events.Packet(pk)); err == nil {
pk = packets.Packet(pkx)
}
}
@@ -449,6 +458,14 @@ func (s *Server) retainMessage(pk packets.Packet) {
func (s *Server) publishToSubscribers(pk packets.Packet) {
for id, qos := range s.Topics.Subscribers(pk.TopicName) {
if client, ok := s.Clients.Get(id); ok {
// If the AllowClients value is set, only deliver the packet if the subscribed
// client exists in the AllowClients value. For use with the OnMessage event hook
// in cases where you want to publish messages to clients selectively.
if pk.AllowClients != nil && !utils.InSliceString(pk.AllowClients, id) {
continue
}
out := pk.PublishCopy()
if qos > out.FixedHeader.Qos { // Inherit higher desired qos values.
out.FixedHeader.Qos = qos

View File

@@ -36,6 +36,15 @@ func setupClient() (s *Server, cl *clients.Client, r net.Conn, w net.Conn) {
return
}
func setupServerClient(s *Server) (cl *clients.Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = clients.NewClient(w, circ.NewReader(256, 8), circ.NewWriter(256, 8), s.System)
cl.ID = "mochi"
cl.AC = new(auth.Allow)
cl.Start()
return
}
func TestNew(t *testing.T) {
s := New()
require.NotNil(t, s)
@@ -220,6 +229,203 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
w.Close()
}
func TestServerEventOnConnect(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedPacket events.Packet
s.Events.OnConnect = func(cl events.Client, pk events.Packet) {
hookedClient = cl
hookedPacket = pk
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
require.Empty(t, clw.Subscriptions)
w.Close()
time.Sleep(10 * time.Millisecond)
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
require.Equal(t, events.Packet(packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Connect,
Remaining: 17,
},
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
ProtocolVersion: 4,
CleanSession: true,
Keepalive: 45,
ClientIdentifier: "mochi",
}), hookedPacket)
}
func TestServerEventOnDisconnect(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedErr error
s.Events.OnDisconnect = func(cl events.Client, err error) {
hookedClient = cl
hookedErr = err
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
// Receive the Connack
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
require.Empty(t, clw.Subscriptions)
w.Close()
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
require.Equal(t, nil, hookedErr)
}
func TestServerEventOnDisconnectOnError(t *testing.T) {
r, w := net.Pipe()
s, cl, _, _ := setupClient()
s.Clients.Add(cl)
var hookedClient events.Client
var hookedErr error
s.Events.OnDisconnect = func(cl events.Client, err error) {
hookedClient = cl
hookedErr = err
}
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 17, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 5, // Client ID - MSB+LSB
'm', 'o', 'c', 'h', 'i', // Client ID
})
w.Write([]byte{0, 0})
}()
// Receive the Connack
go func() {
_, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
}()
clw, ok := s.Clients.Get("mochi")
require.Equal(t, true, ok)
clw.Stop()
errx := <-o
require.Error(t, errx)
require.Equal(t, "No valid packet available; 0", errx.Error())
require.Equal(t, errx, hookedErr)
require.Equal(t, events.Client{
ID: "mochi",
Listener: "tcp",
}, hookedClient)
}
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
s := New()
@@ -433,6 +639,10 @@ func TestServerEstablishConnectionReadPacketErr(t *testing.T) {
require.Error(t, errx)
}
func TestServerOnDisconnectErr(t *testing.T) {
}
func TestServerWriteClient(t *testing.T) {
s, cl, r, w := setupClient()
cl.ID = "mochi"
@@ -558,7 +768,7 @@ func TestServerProcessPublishQoS1Retain(t *testing.T) {
cl1.ID = "mochi1"
s.Clients.Add(cl1)
_, cl2, r2, w2 := setupClient()
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
@@ -687,7 +897,7 @@ func TestServerProcessPublishOfflineQueuing(t *testing.T) {
s.Clients.Add(cl1)
// Start and stop the receiver client
_, cl2, _, _ := setupClient()
cl2, _, _ := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
s.Topics.Subscribe("qos0", cl2.ID, 0)
@@ -922,7 +1132,7 @@ func TestServerPublishInlineSysTopicError(t *testing.T) {
require.Equal(t, int64(0), s.System.BytesSent)
}
func TestServerProcessPublishHookOnMessage(t *testing.T) {
func TestServerEventOnMessage(t *testing.T) {
s, cl1, r1, w1 := setupClient()
s.Clients.Add(cl1)
s.Topics.Subscribe("a/b/+", cl1.ID, 0)
@@ -1071,6 +1281,86 @@ func TestServerProcessPublishHookOnMessageModifyError(t *testing.T) {
require.Equal(t, int64(14), s.System.BytesSent)
}
func TestServerProcessPublishHookOnMessageAllowClients(t *testing.T) {
s, cl1, r1, w1 := setupClient()
cl1.ID = "allowed"
s.Clients.Add(cl1)
s.Topics.Subscribe("a/b/c", cl1.ID, 0)
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "not_allowed"
s.Clients.Add(cl2)
s.Topics.Subscribe("a/b/c", cl2.ID, 0)
s.Topics.Subscribe("d/e/f", cl2.ID, 0)
s.Events.OnMessage = func(cl events.Client, pk events.Packet) (events.Packet, error) {
hookedPacket := pk
if pk.TopicName == "a/b/c" {
hookedPacket.AllowClients = []string{"allowed"}
}
return hookedPacket, nil
}
ack1 := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r1)
if err != nil {
panic(err)
}
ack1 <- buf
}()
ack2 := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r2)
if err != nil {
panic(err)
}
ack2 <- buf
}()
pk1 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "a/b/c",
Payload: []byte("hello"),
}
err := s.processPacket(cl1, pk1)
require.NoError(t, err)
pk2 := packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "d/e/f",
Payload: []byte("a"),
}
err = s.processPacket(cl1, pk2)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
w1.Close()
w2.Close()
require.Equal(t, []byte{
byte(packets.Publish << 4), 12,
0, 5,
'a', '/', 'b', '/', 'c',
'h', 'e', 'l', 'l', 'o',
}, <-ack1)
require.Equal(t, []byte{
byte(packets.Publish << 4), 8,
0, 5,
'd', '/', 'e', '/', 'f',
'a',
}, <-ack2)
require.Equal(t, int64(24), s.System.BytesSent)
}
func TestServerProcessPuback(t *testing.T) {
s, cl, _, _ := setupClient()
cl.Inflight.Set(11, clients.InflightMessage{Packet: packets.Packet{PacketID: 11}, Sent: 0})
@@ -1443,7 +1733,7 @@ func TestServerCloseClientLWT(t *testing.T) {
}
s.Clients.Add(cl1)
_, cl2, r2, w2 := setupClient()
cl2, r2, w2 := setupServerClient(s)
cl2.ID = "mochi2"
s.Clients.Add(cl2)
@@ -1566,14 +1856,14 @@ func TestServerLoadSubscriptions(t *testing.T) {
s.Clients.Add(cl)
subs := []persistence.Subscription{
persistence.Subscription{
{
ID: "test:a/b/c",
Client: "test",
Filter: "a/b/c",
QoS: 1,
T: persistence.KSubscription,
},
persistence.Subscription{
{
ID: "test:d/e/f",
Client: "test",
Filter: "d/e/f",
@@ -1623,7 +1913,7 @@ func TestServerLoadInflight(t *testing.T) {
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
{
ID: "client1_if_0",
T: persistence.KInflight,
Client: "client1",
@@ -1633,7 +1923,7 @@ func TestServerLoadInflight(t *testing.T) {
Sent: 100,
Resends: 0,
},
persistence.Message{
{
ID: "client1_if_100",
T: persistence.KInflight,
Client: "client1",
@@ -1668,7 +1958,7 @@ func TestServerLoadRetained(t *testing.T) {
require.NotNil(t, s)
msgs := []persistence.Message{
persistence.Message{
{
ID: "client1_ret_200",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{
@@ -1680,7 +1970,7 @@ func TestServerLoadRetained(t *testing.T) {
Sent: 100,
Resends: 0,
},
persistence.Message{
{
ID: "client1_ret_300",
T: persistence.KRetained,
FixedHeader: persistence.FixedHeader{

View File

@@ -0,0 +1,21 @@
package system
import (
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
func TestInfoAlignment(t *testing.T) {
typ := reflect.TypeOf(Info{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
switch f.Type.Kind() {
case reflect.Int64, reflect.Uint64:
require.Equalf(t, uintptr(0), f.Offset%8,
"%s requires 64-bit alignment for atomic: offset %d",
f.Name, f.Offset)
}
}
}