mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-11-03 00:24:13 +08:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
86e0a5827e | ||
|
|
06c399b606 | ||
|
|
ed117f67a1 | ||
|
|
880a3299e1 | ||
|
|
1c408d05be | ||
|
|
fce495f83e | ||
|
|
471ca00a64 | ||
|
|
a2c0749640 | ||
|
|
37293aeecf | ||
|
|
7a2d4db6a4 | ||
|
|
03d2a8bc82 | ||
|
|
4b51e5c7d1 | ||
|
|
d15ad682bf | ||
|
|
130ffcbb53 | ||
|
|
33cf2f991b | ||
|
|
a360ea6a6c | ||
|
|
ae3aa0d3fa | ||
|
|
811ae0e1be |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
cmd/mqtt
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
server/persistence/bolt/testbolt.db
|
||||
|
||||
29
README.md
29
README.md
@@ -25,9 +25,11 @@ MQTT stands for MQ Telemetry Transport. It is a publish/subscribe, extremely sim
|
||||
- Interfaces for Client Authentication and Topic access control.
|
||||
- Bolt-backed persistence and storage interfaces.
|
||||
- Directly Publishing from embedding service (`s.Publish(topic, message, retain)`).
|
||||
- Basic Event Hooks (currently `onMessage`)
|
||||
- Basic Event Hooks (currently `OnMessage`, `OnConnect`, `OnDisconnect`).
|
||||
- ARM32 Compatible.
|
||||
|
||||
#### Roadmap
|
||||
- Please open an issue to request new features or event hooks.
|
||||
- MQTT v5 compatibility?
|
||||
|
||||
#### Using the Broker
|
||||
@@ -105,8 +107,30 @@ err := server.AddListener(tcp, &listeners.Config{
|
||||
#### Event Hooks
|
||||
Some basic Event Hooks have been added, allowing you to call your own functions when certain events occur. The execution of the functions are blocking - if necessary, please handle goroutines within the embedding service.
|
||||
|
||||
Working examples can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
|
||||
|
||||
##### OnConnect
|
||||
`server.Events.OnConnect` is called when a client successfully connects to the broker. The method receives the connect packet and the id and connection type for the client who connected.
|
||||
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/server/events"
|
||||
|
||||
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
|
||||
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnDisconnect
|
||||
`server.Events.OnDisconnect` is called when a client disconnects to the broker. If the client disconnected abnormally, the reason is indicated in the `err` error parameter.
|
||||
|
||||
```go
|
||||
server.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
```
|
||||
|
||||
##### OnMessage
|
||||
`server.Events.OnMessage` is called when a Publish packet is received. The function receives the published message and information about the client who published it. This function will block message dispatching until it returns.
|
||||
`server.Events.OnMessage` is called when a Publish packet is received. The method receives the published message and information about the client who published it.
|
||||
|
||||
> This hook is only triggered when a message is received by clients. It is not triggered when using the direct `server.Publish` method.
|
||||
|
||||
@@ -126,7 +150,6 @@ server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.P
|
||||
|
||||
The OnMessage hook can also be used to selectively only deliver messages to one or more clients based on their id, using the `AllowClients []string` field on the packet structure.
|
||||
|
||||
A working example can be found in the `examples/events` folder. Please open an issue if there is a particular event hook you are interested in!
|
||||
|
||||
#### Direct Publishing
|
||||
When the broker is being embedded in a larger codebase, it can be useful to be able to publish messages directly to clients without having to implement a loopback TCP connection with an MQTT client. The `Publish` method allows you to inject publish messages directly into a queue to be delivered to any clients with matching topic filters. The `Retain` flag is supported.
|
||||
|
||||
@@ -44,6 +44,16 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Add OnConnect Event Hook
|
||||
server.Events.OnConnect = func(cl events.Client, pk events.Packet) {
|
||||
fmt.Printf("<< OnConnect client connected %s: %+v\n", cl.ID, pk)
|
||||
}
|
||||
|
||||
// Add OnDisconnect Event Hook
|
||||
server.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
fmt.Printf("<< OnDisconnect client dicconnected %s: %v\n", cl.ID, err)
|
||||
}
|
||||
|
||||
// Add OnMessage Event Hook
|
||||
server.Events.OnMessage = func(cl events.Client, pk events.Packet) (pkx events.Packet, err error) {
|
||||
pkx = pk
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
)
|
||||
|
||||
type Events struct {
|
||||
OnMessage // published message receieved.
|
||||
OnMessage // published message receieved.
|
||||
OnConnect // client connected.
|
||||
OnDisconnect // client disconnected.
|
||||
}
|
||||
|
||||
type Packet packets.Packet
|
||||
@@ -17,7 +19,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
// FromClient returns an event client from a client.
|
||||
func FromClient(cl clients.Client) Client {
|
||||
func FromClient(cl *clients.Client) Client {
|
||||
return Client{
|
||||
ID: cl.ID,
|
||||
Listener: cl.Listener,
|
||||
@@ -34,3 +36,11 @@ func FromClient(cl clients.Client) Client {
|
||||
// This function will block message dispatching until it returns. To minimise this,
|
||||
// have the function open a new goroutine on the embedding side.
|
||||
type OnMessage func(Client, Packet) (Packet, error)
|
||||
|
||||
// OnConnect is called when a client successfully connects to the broker.
|
||||
type OnConnect func(Client, Packet)
|
||||
|
||||
// OnDisconnect is called when a client disconnects to the broker. An error value
|
||||
// is passed to the function if the client disconnected abnormally, otherwise it
|
||||
// will be nil on a normal disconnect.
|
||||
type OnDisconnect func(Client, error)
|
||||
|
||||
@@ -24,12 +24,13 @@ type Buffer struct {
|
||||
block int // the size of the R/W block.
|
||||
buf []byte // the bytes buffer.
|
||||
tmp []byte // a temporary buffer.
|
||||
_ int32 // align the next fields to an 8-byte boundary for atomic access.
|
||||
head int64 // the current position in the sequence - a forever increasing index.
|
||||
tail int64 // the committed position in the sequence - a forever increasing index.
|
||||
rcond *sync.Cond // the sync condition for the buffer reader.
|
||||
wcond *sync.Cond // the sync condition for the buffer writer.
|
||||
done int64 // indicates that the buffer is closed.
|
||||
State int64 // indicates whether the buffer is reading from (1) or writing to (2).
|
||||
done uint32 // indicates that the buffer is closed.
|
||||
State uint32 // indicates whether the buffer is reading from (1) or writing to (2).
|
||||
}
|
||||
|
||||
// NewBuffer returns a new instance of buffer. You should call NewReader or
|
||||
@@ -127,7 +128,7 @@ func (b *Buffer) awaitEmpty(n int) error {
|
||||
// then wait until tail has moved.
|
||||
b.rcond.L.Lock()
|
||||
for !b.checkEmpty(n) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
b.rcond.L.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
@@ -146,7 +147,7 @@ func (b *Buffer) awaitFilled(n int) error {
|
||||
// the forever-incrementing tail and head integers.
|
||||
b.wcond.L.Lock()
|
||||
for !b.checkFilled(n) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
b.wcond.L.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
@@ -196,7 +197,7 @@ func (b *Buffer) CapDelta() int {
|
||||
|
||||
// Stop signals the buffer to stop processing.
|
||||
func (b *Buffer) Stop() {
|
||||
atomic.StoreInt64(&b.done, 1)
|
||||
atomic.StoreUint32(&b.done, 1)
|
||||
b.rcond.L.Lock()
|
||||
b.rcond.Broadcast()
|
||||
b.rcond.L.Unlock()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -50,6 +51,18 @@ func TestNewBufferFromSlice0Size(t *testing.T) {
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestAtomicAlignment(t *testing.T) {
|
||||
var b Buffer
|
||||
|
||||
offset := unsafe.Offsetof(b.head)
|
||||
require.Equalf(t, uintptr(0), offset%8,
|
||||
"head requires 64-bit alignment for atomic: offset %d", offset)
|
||||
|
||||
offset = unsafe.Offsetof(b.tail)
|
||||
require.Equalf(t, uintptr(0), offset%8,
|
||||
"tail requires 64-bit alignment for atomic: offset %d", offset)
|
||||
}
|
||||
|
||||
func TestGetPos(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
tail, head := buf.GetPos()
|
||||
@@ -140,7 +153,7 @@ func TestAwaitFilledEnded(t *testing.T) {
|
||||
o <- buf.awaitFilled(4)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
@@ -190,7 +203,7 @@ func TestAwaitEmptyEnded(t *testing.T) {
|
||||
o <- buf.awaitEmpty(4)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
@@ -280,7 +293,7 @@ func TestCommitTailEnded(t *testing.T) {
|
||||
o <- buf.CommitTail(5)
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
@@ -300,5 +313,5 @@ func TestCapDelta(t *testing.T) {
|
||||
func TestStop(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
buf.Stop()
|
||||
require.Equal(t, int64(1), buf.done)
|
||||
require.Equal(t, uint32(1), buf.done)
|
||||
}
|
||||
|
||||
@@ -32,10 +32,10 @@ func NewReaderFromSlice(block int, p []byte) *Reader {
|
||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
||||
// there is sufficient capacity to do so.
|
||||
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
atomic.StoreInt64(&b.State, 1)
|
||||
defer atomic.StoreInt64(&b.State, 0)
|
||||
atomic.StoreUint32(&b.State, 1)
|
||||
defer atomic.StoreUint32(&b.State, 0)
|
||||
for {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
if atomic.LoadUint32(&b.done) == 1 {
|
||||
return total, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ func TestReadFromWrap(t *testing.T) {
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
go func() {
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
@@ -116,7 +116,7 @@ func TestReadEnded(t *testing.T) {
|
||||
o <- err
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
@@ -32,10 +32,10 @@ func NewWriterFromSlice(block int, p []byte) *Writer {
|
||||
|
||||
// WriteTo writes the contents of the buffer to an io.Writer.
|
||||
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
|
||||
atomic.StoreInt64(&b.State, 2)
|
||||
defer atomic.StoreInt64(&b.State, 0)
|
||||
atomic.StoreUint32(&b.State, 2)
|
||||
defer atomic.StoreUint32(&b.State, 0)
|
||||
for {
|
||||
if atomic.LoadInt64(&b.done) == 1 && b.CapDelta() == 0 {
|
||||
if atomic.LoadUint32(&b.done) == 1 && b.CapDelta() == 0 {
|
||||
return total, io.EOF
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func TestWriteTo(t *testing.T) {
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
atomic.StoreUint32(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
@@ -74,7 +74,7 @@ func (cl *Clients) GetByListener(id string) []*Client {
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
cl.RLock()
|
||||
for _, v := range cl.internal {
|
||||
if v.Listener == id && atomic.LoadInt64(&v.State.Done) == 0 {
|
||||
if v.Listener == id && atomic.LoadUint32(&v.State.Done) == 0 {
|
||||
clients = append(clients, v)
|
||||
}
|
||||
}
|
||||
@@ -104,7 +104,7 @@ type Client struct {
|
||||
|
||||
// State tracks the state of the client.
|
||||
type State struct {
|
||||
Done int64 // atomic counter which indicates that the client has closed.
|
||||
Done uint32 // atomic counter which indicates that the client has closed.
|
||||
started *sync.WaitGroup // tracks the goroutines which have been started.
|
||||
endedW *sync.WaitGroup // tracks when the writer has ended.
|
||||
endedR *sync.WaitGroup // tracks when the reader has ended.
|
||||
@@ -240,7 +240,7 @@ func (cl *Client) Start() {
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop() {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ func (cl *Client) Stop() {
|
||||
cl.conn.Close()
|
||||
|
||||
cl.State.endedR.Wait()
|
||||
atomic.StoreInt64(&cl.State.Done, 1)
|
||||
atomic.StoreUint32(&cl.State.Done, 1)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -309,7 +309,7 @@ func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
// an error is encountered (or the connection is closed).
|
||||
func (cl *Client) Read(packetHandler func(*Client, packets.Packet) error) error {
|
||||
for {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 && cl.r.CapDelta() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -391,7 +391,7 @@ func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err er
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) (n int, err error) {
|
||||
if atomic.LoadInt64(&cl.State.Done) == 1 {
|
||||
if atomic.LoadUint32(&cl.State.Done) == 1 {
|
||||
return 0, ErrConnectionClosed
|
||||
}
|
||||
|
||||
|
||||
@@ -316,8 +316,8 @@ func TestClientStart(t *testing.T) {
|
||||
cl.Start()
|
||||
defer cl.Stop()
|
||||
time.Sleep(time.Millisecond)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.r.State))
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.w.State))
|
||||
require.Equal(t, uint32(1), atomic.LoadUint32(&cl.r.State))
|
||||
require.Equal(t, uint32(2), atomic.LoadUint32(&cl.w.State))
|
||||
}
|
||||
|
||||
func BenchmarkClientStart(b *testing.B) {
|
||||
|
||||
@@ -16,7 +16,7 @@ type FixedHeader struct {
|
||||
// Encode encodes the FixedHeader and returns a bytes buffer.
|
||||
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
|
||||
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
|
||||
encodeLength(buf, fh.Remaining)
|
||||
encodeLength(buf, int64(fh.Remaining))
|
||||
}
|
||||
|
||||
// decode extracts the specification bits from the header byte.
|
||||
@@ -44,7 +44,7 @@ func (fh *FixedHeader) Decode(headerByte byte) error {
|
||||
}
|
||||
|
||||
// encodeLength writes length bits for the header.
|
||||
func encodeLength(buf *bytes.Buffer, length int) {
|
||||
func encodeLength(buf *bytes.Buffer, length int64) {
|
||||
for {
|
||||
digit := byte(length % 128)
|
||||
length /= 128
|
||||
|
||||
@@ -192,7 +192,7 @@ func BenchmarkFixedHeaderDecode(b *testing.B) {
|
||||
|
||||
func TestEncodeLength(t *testing.T) {
|
||||
tt := []struct {
|
||||
have int
|
||||
have int64
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ type HTTPStats struct {
|
||||
system *system.Info // pointers to the server data.
|
||||
address string // the network address to bind to.
|
||||
listen *http.Server // the http server.
|
||||
end int64 // ensure the close methods are only called once.}
|
||||
end uint32 // ensure the close methods are only called once.}
|
||||
}
|
||||
|
||||
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
|
||||
@@ -98,9 +98,7 @@ func (l *HTTPStats) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
|
||||
@@ -18,7 +18,7 @@ type TCP struct {
|
||||
protocol string // the TCP protocol to use.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
end int64 // ensure the close methods are only called once.
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewTCP initialises and returns a new TCP listener, listening on an address.
|
||||
@@ -86,7 +86,7 @@ func (l *TCP) Listen(s *system.Info) error {
|
||||
// connection callback for any received.
|
||||
func (l *TCP) Serve(establish EstablishFunc) {
|
||||
for {
|
||||
if atomic.LoadInt64(&l.end) == 1 {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ func (l *TCP) Serve(establish EstablishFunc) {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go establish(l.id, conn, l.config.Auth)
|
||||
}
|
||||
}
|
||||
@@ -106,8 +106,7 @@ func (l *TCP) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ type Websocket struct {
|
||||
config *Config // configuration values for the listener.
|
||||
address string // the network address to bind to.
|
||||
listen *http.Server // an http server for serving websocket connections.
|
||||
end int64 // ensure the close methods are only called once.
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
establish EstablishFunc // the server's establish conection handler.
|
||||
}
|
||||
|
||||
@@ -162,8 +162,7 @@ func (l *Websocket) Close(closeClients CloseFunc) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.LoadInt64(&l.end) == 0 {
|
||||
atomic.StoreInt64(&l.end, 1)
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "1.0.3" // the server version.
|
||||
Version = "1.1.0" // the server version.
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -208,7 +208,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
var sessionPresent bool
|
||||
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
|
||||
existing.Lock()
|
||||
if atomic.LoadInt64(&existing.State.Done) == 1 {
|
||||
if atomic.LoadUint32(&existing.State.Done) == 1 {
|
||||
atomic.AddInt64(&s.System.ClientsDisconnected, -1)
|
||||
}
|
||||
existing.Stop()
|
||||
@@ -259,6 +259,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
})
|
||||
}
|
||||
|
||||
if s.Events.OnConnect != nil {
|
||||
s.Events.OnConnect(events.FromClient(cl), events.Packet(pk))
|
||||
}
|
||||
|
||||
err = cl.Read(s.processPacket)
|
||||
if err != nil {
|
||||
s.closeClient(cl, true)
|
||||
@@ -270,6 +274,10 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
atomic.AddInt64(&s.System.ClientsConnected, -1)
|
||||
atomic.AddInt64(&s.System.ClientsDisconnected, 1)
|
||||
|
||||
if s.Events.OnDisconnect != nil {
|
||||
s.Events.OnDisconnect(events.FromClient(cl), err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -413,7 +421,7 @@ func (s *Server) processPublish(cl *clients.Client, pk packets.Packet) error {
|
||||
|
||||
// if an OnMessage hook exists, potentially modify the packet.
|
||||
if s.Events.OnMessage != nil {
|
||||
if pkx, err := s.Events.OnMessage(events.FromClient(*cl), events.Packet(pk)); err == nil {
|
||||
if pkx, err := s.Events.OnMessage(events.FromClient(cl), events.Packet(pk)); err == nil {
|
||||
pk = packets.Packet(pkx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,6 +229,203 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
|
||||
w.Close()
|
||||
}
|
||||
|
||||
func TestServerEventOnConnect(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
s, cl, _, _ := setupClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
var hookedClient events.Client
|
||||
var hookedPacket events.Packet
|
||||
s.Events.OnConnect = func(cl events.Client, pk events.Packet) {
|
||||
hookedClient = cl
|
||||
hookedPacket = pk
|
||||
}
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
|
||||
}()
|
||||
|
||||
go func() {
|
||||
w.Write([]byte{
|
||||
byte(packets.Connect << 4), 17, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
2, // Packet Flags - clean session
|
||||
0, 45, // Keepalive
|
||||
0, 5, // Client ID - MSB+LSB
|
||||
'm', 'o', 'c', 'h', 'i', // Client ID
|
||||
})
|
||||
w.Write([]byte{byte(packets.Disconnect << 4), 0})
|
||||
}()
|
||||
|
||||
// Receive the Connack
|
||||
recv := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(w)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
recv <- buf
|
||||
}()
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
require.Equal(t, true, ok)
|
||||
clw.Stop()
|
||||
|
||||
errx := <-o
|
||||
require.NoError(t, errx)
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Connack << 4), 2,
|
||||
0, packets.Accepted,
|
||||
}, <-recv)
|
||||
require.Empty(t, clw.Subscriptions)
|
||||
|
||||
w.Close()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Listener: "tcp",
|
||||
}, hookedClient)
|
||||
|
||||
require.Equal(t, events.Packet(packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Connect,
|
||||
Remaining: 17,
|
||||
},
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
ProtocolVersion: 4,
|
||||
CleanSession: true,
|
||||
Keepalive: 45,
|
||||
ClientIdentifier: "mochi",
|
||||
}), hookedPacket)
|
||||
|
||||
}
|
||||
|
||||
func TestServerEventOnDisconnect(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
s, cl, _, _ := setupClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
var hookedClient events.Client
|
||||
var hookedErr error
|
||||
s.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
hookedClient = cl
|
||||
hookedErr = err
|
||||
}
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
|
||||
}()
|
||||
|
||||
go func() {
|
||||
w.Write([]byte{
|
||||
byte(packets.Connect << 4), 17, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
2, // Packet Flags - clean session
|
||||
0, 45, // Keepalive
|
||||
0, 5, // Client ID - MSB+LSB
|
||||
'm', 'o', 'c', 'h', 'i', // Client ID
|
||||
})
|
||||
w.Write([]byte{byte(packets.Disconnect << 4), 0})
|
||||
}()
|
||||
|
||||
// Receive the Connack
|
||||
recv := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(w)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
recv <- buf
|
||||
}()
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
require.Equal(t, true, ok)
|
||||
clw.Stop()
|
||||
|
||||
errx := <-o
|
||||
require.NoError(t, errx)
|
||||
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Connack << 4), 2,
|
||||
0, packets.Accepted,
|
||||
}, <-recv)
|
||||
require.Empty(t, clw.Subscriptions)
|
||||
|
||||
w.Close()
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Listener: "tcp",
|
||||
}, hookedClient)
|
||||
|
||||
require.Equal(t, nil, hookedErr)
|
||||
}
|
||||
|
||||
func TestServerEventOnDisconnectOnError(t *testing.T) {
|
||||
r, w := net.Pipe()
|
||||
s, cl, _, _ := setupClient()
|
||||
s.Clients.Add(cl)
|
||||
|
||||
var hookedClient events.Client
|
||||
var hookedErr error
|
||||
s.Events.OnDisconnect = func(cl events.Client, err error) {
|
||||
hookedClient = cl
|
||||
hookedErr = err
|
||||
}
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
|
||||
}()
|
||||
|
||||
go func() {
|
||||
w.Write([]byte{
|
||||
byte(packets.Connect << 4), 17, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
2, // Packet Flags - clean session
|
||||
0, 45, // Keepalive
|
||||
0, 5, // Client ID - MSB+LSB
|
||||
'm', 'o', 'c', 'h', 'i', // Client ID
|
||||
})
|
||||
|
||||
w.Write([]byte{0, 0})
|
||||
}()
|
||||
|
||||
// Receive the Connack
|
||||
go func() {
|
||||
_, err := ioutil.ReadAll(w)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
clw, ok := s.Clients.Get("mochi")
|
||||
require.Equal(t, true, ok)
|
||||
clw.Stop()
|
||||
|
||||
errx := <-o
|
||||
require.Error(t, errx)
|
||||
require.Equal(t, "No valid packet available; 0", errx.Error())
|
||||
fmt.Println(hookedErr)
|
||||
require.Equal(t, errx, hookedErr)
|
||||
|
||||
require.Equal(t, events.Client{
|
||||
ID: "mochi",
|
||||
Listener: "tcp",
|
||||
}, hookedClient)
|
||||
|
||||
}
|
||||
|
||||
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
|
||||
s := New()
|
||||
|
||||
@@ -442,6 +639,10 @@ func TestServerEstablishConnectionReadPacketErr(t *testing.T) {
|
||||
require.Error(t, errx)
|
||||
}
|
||||
|
||||
func TestServerOnDisconnectErr(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestServerWriteClient(t *testing.T) {
|
||||
s, cl, r, w := setupClient()
|
||||
cl.ID = "mochi"
|
||||
@@ -931,7 +1132,7 @@ func TestServerPublishInlineSysTopicError(t *testing.T) {
|
||||
require.Equal(t, int64(0), s.System.BytesSent)
|
||||
}
|
||||
|
||||
func TestServerProcessPublishHookOnMessage(t *testing.T) {
|
||||
func TestServerEventOnMessage(t *testing.T) {
|
||||
s, cl1, r1, w1 := setupClient()
|
||||
s.Clients.Add(cl1)
|
||||
s.Topics.Subscribe("a/b/+", cl1.ID, 0)
|
||||
|
||||
21
server/system/system_test.go
Normal file
21
server/system/system_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInfoAlignment(t *testing.T) {
|
||||
typ := reflect.TypeOf(Info{})
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
switch f.Type.Kind() {
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
require.Equalf(t, uintptr(0), f.Offset%8,
|
||||
"%s requires 64-bit alignment for atomic: offset %d",
|
||||
f.Name, f.Offset)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user