mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-24 16:40:26 +08:00
bytes buffer to pool
This commit is contained in:
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
@@ -11,9 +12,14 @@ import (
|
||||
|
||||
"github.com/mochi-co/mqtt"
|
||||
"github.com/mochi-co/mqtt/internal/listeners"
|
||||
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
|
||||
1
go.mod
1
go.mod
@@ -6,6 +6,7 @@ require (
|
||||
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a
|
||||
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23
|
||||
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164
|
||||
github.com/pkg/profile v1.4.0
|
||||
github.com/rs/xid v1.2.1
|
||||
github.com/stretchr/testify v1.4.0
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@@ -8,6 +8,8 @@ github.com/mochi-co/debug v0.0.0-20191124114744-82bf8b6739b8 h1:BIY2BMCLHm6hE/SU
|
||||
github.com/mochi-co/debug v0.0.0-20191124114744-82bf8b6739b8/go.mod h1:AqE7zHPhLOj61seX0vXvzpGiD9Q3Bx5LQPf/FleHKWc=
|
||||
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164 h1:XGYo79ZRE9pQE9B5iZCYw3VLaq88PfxcdvDf9crG+dQ=
|
||||
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164/go.mod h1:LfBrWXdsMaDKL0ZjcbnLjeYL48Nlo1nW4MltMDYqr44=
|
||||
github.com/pkg/profile v1.4.0 h1:uCmaf4vVbWAOZz36k1hrQD7ijGRzLwaME8Am/7a4jZI=
|
||||
github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc=
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultBufferSize int = 2048 // the default size of the buffer in bytes.
|
||||
DefaultBlockSize int = 128 // the default size per R/W block in bytes.
|
||||
DefaultBufferSize int = 1024 * 256 // the default size of the buffer in bytes.
|
||||
DefaultBlockSize int = 1024 * 8 // the default size per R/W block in bytes.
|
||||
|
||||
ErrOutOfRange = fmt.Errorf("Indexes out of range")
|
||||
ErrInsufficientBytes = fmt.Errorf("Insufficient bytes to return")
|
||||
@@ -43,6 +43,7 @@ func NewBuffer(size, block int) Buffer {
|
||||
if block == 0 {
|
||||
block = DefaultBlockSize
|
||||
}
|
||||
|
||||
if size < 2*block {
|
||||
size = 2 * block
|
||||
}
|
||||
@@ -52,12 +53,32 @@ func NewBuffer(size, block int) Buffer {
|
||||
mask: size - 1,
|
||||
block: block,
|
||||
buf: make([]byte, size),
|
||||
tmp: make([]byte, size),
|
||||
rcond: sync.NewCond(new(sync.Mutex)),
|
||||
wcond: sync.NewCond(new(sync.Mutex)),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBufferFromSlice returns a new instance of buffer using a
|
||||
// pre-existing byte slice.
|
||||
func NewBufferFromSlice(block int, buf []byte) Buffer {
|
||||
l := len(buf)
|
||||
|
||||
if block == 0 {
|
||||
block = DefaultBlockSize
|
||||
}
|
||||
|
||||
b := Buffer{
|
||||
size: l,
|
||||
mask: l - 1,
|
||||
block: block,
|
||||
buf: buf,
|
||||
rcond: sync.NewCond(new(sync.Mutex)),
|
||||
wcond: sync.NewCond(new(sync.Mutex)),
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Get will return the tail and head positions of the buffer.
|
||||
// This method is for use with testing.
|
||||
func (b *Buffer) GetPos() (int64, int64) {
|
||||
|
||||
@@ -36,6 +36,20 @@ func TestNewBufferUndersize(t *testing.T) {
|
||||
require.Equal(t, DefaultBlockSize, buf.block)
|
||||
}
|
||||
|
||||
func TestNewBufferFromSlice(t *testing.T) {
|
||||
b := NewBytesPool(256)
|
||||
buf := NewBufferFromSlice(DefaultBlockSize, b.Get())
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestNewBufferFromSlice0Size(t *testing.T) {
|
||||
b := NewBytesPool(256)
|
||||
buf := NewBufferFromSlice(0, b.Get())
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestGetPos(t *testing.T) {
|
||||
buf := NewBuffer(16, 4)
|
||||
tail, head := buf.GetPos()
|
||||
|
||||
32
internal/circ/pool.go
Normal file
32
internal/circ/pool.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// BytesPool is a pool of []byte
|
||||
type BytesPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
// NewBytesPool returns a sync.pool of []byte
|
||||
func NewBytesPool(n int) BytesPool {
|
||||
return BytesPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, n)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a pooled bytes.Buffer.
|
||||
func (b BytesPool) Get() []byte {
|
||||
return b.pool.Get().([]byte)
|
||||
}
|
||||
|
||||
// Put puts the byte slice back into the pool.
|
||||
func (b BytesPool) Put(x []byte) {
|
||||
x = x[:0]
|
||||
b.pool.Put(x)
|
||||
}
|
||||
46
internal/circ/pool_test.go
Normal file
46
internal/circ/pool_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package circ
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBytesPool(t *testing.T) {
|
||||
bpool := NewBytesPool(256)
|
||||
require.NotNil(t, bpool.pool)
|
||||
}
|
||||
|
||||
func BenchmarkNewBytesPool(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
NewBytesPool(256)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBytesPoolGet(t *testing.T) {
|
||||
bpool := NewBytesPool(256)
|
||||
buf := bpool.Get()
|
||||
|
||||
require.Equal(t, make([]byte, 256), buf)
|
||||
}
|
||||
|
||||
func BenchmarkBytesPoolGet(b *testing.B) {
|
||||
bpool := NewBytesPool(256)
|
||||
for n := 0; n < b.N; n++ {
|
||||
bpool.Get()
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBytesPoolPut(t *testing.T) {
|
||||
bpool := NewBytesPool(256)
|
||||
buf := bpool.Get()
|
||||
bpool.Put(buf)
|
||||
}
|
||||
|
||||
func BenchmarkBytesPoolPut(b *testing.B) {
|
||||
bpool := NewBytesPool(256)
|
||||
buf := bpool.Get()
|
||||
for n := 0; n < b.N; n++ {
|
||||
bpool.Put(buf)
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,6 @@ package circ
|
||||
import (
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
dbg "github.com/mochi-co/debug"
|
||||
)
|
||||
|
||||
// Reader is a circular buffer for reading data from an io.Reader.
|
||||
@@ -12,7 +10,7 @@ type Reader struct {
|
||||
Buffer
|
||||
}
|
||||
|
||||
// NewReader returns a pointer to a new Circular Reader.
|
||||
// NewReader returns a new Circular Reader.
|
||||
func NewReader(size, block int) *Reader {
|
||||
b := NewBuffer(size, block)
|
||||
b.ID = "\treader"
|
||||
@@ -21,6 +19,16 @@ func NewReader(size, block int) *Reader {
|
||||
}
|
||||
}
|
||||
|
||||
// NewReaderFromSlice returns a new Circular Reader using a pre-exising
|
||||
// byte slice.
|
||||
func NewReaderFromSlice(block int, p []byte) *Reader {
|
||||
b := NewBufferFromSlice(block, p)
|
||||
b.ID = "\treader"
|
||||
return &Reader{
|
||||
b,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
||||
// there is sufficient capacity to do so.
|
||||
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
@@ -48,8 +56,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
end = b.size
|
||||
}
|
||||
|
||||
dbg.Println(dbg.Yellow, b.ID, "b.ReadFrom allocating", start, ":", end)
|
||||
|
||||
// Read into the buffer between the start and end indexes only.
|
||||
n, err := r.Read(b.buf[start:end])
|
||||
total += int64(n) // incr total bytes read.
|
||||
@@ -57,8 +63,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
return total, nil
|
||||
}
|
||||
|
||||
dbg.Println(dbg.HiYellow, b.ID, "b.ReadFrom received", n, b.buf[start:start+n])
|
||||
|
||||
// Move the head forward however many bytes were read.
|
||||
atomic.AddInt64(&b.head, int64(n))
|
||||
|
||||
@@ -71,8 +75,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
// Read reads n bytes from the buffer, and will block until at n bytes
|
||||
// exist in the buffer to read.
|
||||
func (b *Buffer) Read(n int) (p []byte, err error) {
|
||||
dbg.Println(dbg.Cyan, b.ID, "b.Read waiting for", n, "bytes")
|
||||
|
||||
err = b.awaitFilled(n)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -90,7 +92,5 @@ func (b *Buffer) Read(n int) (p []byte, err error) {
|
||||
b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read.
|
||||
}
|
||||
|
||||
dbg.Println(dbg.HiCyan, b.ID, "b.Read read", tail, next, b.tmp)
|
||||
|
||||
return b.tmp, nil
|
||||
}
|
||||
|
||||
@@ -20,6 +20,13 @@ func TestNewReader(t *testing.T) {
|
||||
require.Equal(t, block, buf.block)
|
||||
}
|
||||
|
||||
func TestNewReaderFromSlice(t *testing.T) {
|
||||
b := NewBytesPool(256)
|
||||
buf := NewReaderFromSlice(DefaultBlockSize, b.Get())
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestReadFrom(t *testing.T) {
|
||||
buf := NewReader(16, 4)
|
||||
|
||||
|
||||
@@ -19,6 +19,16 @@ func NewWriter(size, block int) *Writer {
|
||||
}
|
||||
}
|
||||
|
||||
// NewWriterFromSlice returns a new Circular Writer using a pre-exising
|
||||
// byte slice.
|
||||
func NewWriterFromSlice(block int, p []byte) *Writer {
|
||||
b := NewBufferFromSlice(block, p)
|
||||
b.ID = "writer"
|
||||
return &Writer{
|
||||
b,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo writes the contents of the buffer to an io.Writer.
|
||||
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
|
||||
atomic.StoreInt64(&b.State, 2)
|
||||
|
||||
@@ -22,6 +22,13 @@ func TestNewWriter(t *testing.T) {
|
||||
require.Equal(t, block, buf.block)
|
||||
}
|
||||
|
||||
func TestNewWriterFromSlice(t *testing.T) {
|
||||
b := NewBytesPool(256)
|
||||
buf := NewWriterFromSlice(DefaultBlockSize, b.Get())
|
||||
require.NotNil(t, buf.buf)
|
||||
require.Equal(t, 256, cap(buf.buf))
|
||||
}
|
||||
|
||||
func TestWriteTo(t *testing.T) {
|
||||
tests := []struct {
|
||||
tail int64
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
dbg "github.com/mochi-co/debug"
|
||||
"github.com/mochi-co/mqtt/internal/auth"
|
||||
"github.com/mochi-co/mqtt/internal/circ"
|
||||
"github.com/mochi-co/mqtt/internal/packets"
|
||||
@@ -210,27 +209,26 @@ func (cl *Client) Start() {
|
||||
|
||||
go func() {
|
||||
cl.state.started.Done()
|
||||
_, err := cl.w.WriteTo(cl.conn)
|
||||
dbg.Println(dbg.HiRed, cl.ID, "WriteTo stopped", err)
|
||||
//_, err :=
|
||||
cl.w.WriteTo(cl.conn)
|
||||
cl.state.endedW.Done()
|
||||
//cl.close()
|
||||
}()
|
||||
cl.state.endedW.Add(1)
|
||||
|
||||
go func() {
|
||||
cl.state.started.Done()
|
||||
_, err := cl.r.ReadFrom(cl.conn)
|
||||
dbg.Println(dbg.HiRed, cl.ID, "ReadFrom stopped", err)
|
||||
//_, err :=
|
||||
cl.r.ReadFrom(cl.conn)
|
||||
cl.state.endedR.Done()
|
||||
//cl.close()
|
||||
}()
|
||||
|
||||
cl.state.endedR.Add(1)
|
||||
cl.state.started.Wait()
|
||||
}
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop() {
|
||||
dbg.Println(dbg.HiRed+"CLIENT stop called...", dbg.Underline+cl.ID)
|
||||
cl.r.Stop()
|
||||
cl.w.Stop()
|
||||
cl.state.endedW.Wait()
|
||||
@@ -241,7 +239,6 @@ func (cl *Client) Stop() {
|
||||
}
|
||||
|
||||
cl.state.endedR.Wait()
|
||||
dbg.Println(dbg.HiRed+"CLIENT stopped", dbg.Underline+cl.ID)
|
||||
}
|
||||
|
||||
// readFixedHeader reads in the values of the next packet's fixed header.
|
||||
|
||||
135
mqtt.go
135
mqtt.go
@@ -23,28 +23,11 @@ var (
|
||||
ErrListenerIDExists = errors.New("Listener id already exists")
|
||||
ErrReadConnectInvalid = errors.New("Connect packet was not valid")
|
||||
ErrConnectNotAuthorized = errors.New("Connect packet was not authorized")
|
||||
|
||||
// ErrACLNotAuthorized = errors.New("ACL not authorized")
|
||||
)
|
||||
|
||||
/*
|
||||
var (
|
||||
ErrListenerIDExists = errors.New("Listener id already exists")
|
||||
ErrReadConnectFixedHeader = errors.New("Error reading fixed header on CONNECT packet")
|
||||
ErrReadConnectPacket = errors.New("Error reading CONNECT packet")
|
||||
ErrReadConnectInvalid = errors.New("CONNECT packet was not valid")
|
||||
|
||||
ErrReadFixedHeader = errors.New("Error reading fixed header")
|
||||
ErrReadPacketPayload = errors.New("Error reading packet payload")
|
||||
ErrReadPacketValidation = errors.New("Error validating packet")
|
||||
ErrConnectionClosed = errors.New("Connection not open")
|
||||
ErrNoData = errors.New("No data")
|
||||
ErrACLNotAuthorized = errors.New("ACL not authorized")
|
||||
)
|
||||
*/
|
||||
|
||||
// Server is an MQTT broker server.
|
||||
type Server struct {
|
||||
bytepool circ.BytesPool
|
||||
Listeners listeners.Listeners // listeners listen for new connections.
|
||||
Clients clients.Clients // clients known to the broker.
|
||||
Topics *topics.Index // an index of topic subscriptions and retained messages.
|
||||
@@ -52,8 +35,8 @@ type Server struct {
|
||||
|
||||
// New returns a new instance of an MQTT broker.
|
||||
func New() *Server {
|
||||
fmt.Println()
|
||||
return &Server{
|
||||
bytepool: circ.NewBytesPool(circ.DefaultBufferSize),
|
||||
Listeners: listeners.New(),
|
||||
Clients: clients.New(),
|
||||
Topics: topics.New(),
|
||||
@@ -82,53 +65,14 @@ func (s *Server) Serve() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close attempts to gracefully shutdown the server, all listeners, and clients.
|
||||
func (s *Server) Close() error {
|
||||
s.Listeners.CloseAll(s.closeListenerClients)
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeListenerClients closes all clients on the specified listener.
|
||||
func (s *Server) closeListenerClients(listener string) {
|
||||
clients := s.Clients.GetByListener(listener)
|
||||
for _, client := range clients {
|
||||
s.closeClient(client, false) // omit errors
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// closeClient closes a client connection and publishes any LWT messages.
|
||||
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
|
||||
|
||||
//debug.Println(cl.ID, "SERVER STOPS ISSUED >> ")
|
||||
|
||||
// If an LWT message is set, publish it to the topic subscribers.
|
||||
/* // this currently loops forever on broken connection
|
||||
if sendLWT && cl.lwt.topic != "" {
|
||||
err := s.processPublish(cl, &packets.PublishPacket{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Retain: cl.lwt.retain,
|
||||
Qos: cl.lwt.qos,
|
||||
},
|
||||
TopicName: cl.lwt.topic,
|
||||
Payload: cl.lwt.message,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// Stop listening for new packets.
|
||||
cl.Stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EstablishConnection establishes a new client connection with the broker.
|
||||
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
|
||||
client := clients.NewClient(c, circ.NewReader(0, 0), circ.NewWriter(0, 0))
|
||||
//client := clients.NewClient(c, circ.NewReader(0, 0), circ.NewWriter(0, 0))
|
||||
client := clients.NewClient(c,
|
||||
circ.NewReaderFromSlice(0, s.bytepool.Get()),
|
||||
circ.NewWriterFromSlice(0, s.bytepool.Get()),
|
||||
)
|
||||
|
||||
client.Start()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
@@ -151,6 +95,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
if !ac.Authenticate(pk.Username, pk.Password) {
|
||||
retcode = packets.CodeConnectBadAuthValues
|
||||
}
|
||||
|
||||
var sessionPresent bool
|
||||
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
|
||||
existing.Lock()
|
||||
@@ -168,10 +113,8 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
existing.Unlock()
|
||||
}
|
||||
|
||||
// Add the new client to the clients manager.
|
||||
s.Clients.Add(client)
|
||||
|
||||
// Send a CONNACK back to the client with retcode.
|
||||
err = s.writeClient(client, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Connack,
|
||||
@@ -184,21 +127,17 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
||||
return err
|
||||
}
|
||||
|
||||
// Resend any unacknowledged QOS messages still pending for the client.
|
||||
/*err = s.resendInflight(client)
|
||||
err = s.ResendInflight(client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*/
|
||||
|
||||
// Block and listen for more packets, and end if an error or nil packet occurs.
|
||||
var sendLWT bool
|
||||
err = client.Read(s.processPacket)
|
||||
if err != nil {
|
||||
sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3]
|
||||
}
|
||||
|
||||
// Publish last will and testament then close.
|
||||
s.closeClient(client, sendLWT)
|
||||
|
||||
return err
|
||||
@@ -217,13 +156,10 @@ func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resendInflight republishes any inflight messages to the client.
|
||||
/*func (s *Server) resendInflight(cl *clients.Client) error {
|
||||
cl.RLock()
|
||||
msgs := cl.inFlight.internal
|
||||
cl.RUnlock()
|
||||
for _, msg := range msgs {
|
||||
err := s.writeClient(cl, msg.packet)
|
||||
// ResendInflight republishes any inflight messages to the client.
|
||||
func (s *Server) ResendInflight(cl *clients.Client) error {
|
||||
for _, pk := range cl.InFlight.GetAll() {
|
||||
err := s.writeClient(cl, pk.Packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -231,7 +167,6 @@ func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
|
||||
// processPacket processes an inbound packet for a client. Since the method is
|
||||
// typically called as a goroutine, errors are mostly for test checking purposes.
|
||||
@@ -453,3 +388,45 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) (clos
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Close attempts to gracefully shutdown the server, all listeners, and clients.
|
||||
func (s *Server) Close() error {
|
||||
s.Listeners.CloseAll(s.closeListenerClients)
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeListenerClients closes all clients on the specified listener.
|
||||
func (s *Server) closeListenerClients(listener string) {
|
||||
clients := s.Clients.GetByListener(listener)
|
||||
for _, client := range clients {
|
||||
s.closeClient(client, false) // omit errors
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// closeClient closes a client connection and publishes any LWT messages.
|
||||
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
|
||||
// If an LWT message is set, publish it to the topic subscribers.
|
||||
|
||||
/*
|
||||
if sendLWT && cl.lwt.topic != "" {
|
||||
err := s.processPublish(cl, &packets.PublishPacket{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Retain: cl.lwt.retain,
|
||||
Qos: cl.lwt.qos,
|
||||
},
|
||||
TopicName: cl.lwt.topic,
|
||||
Payload: cl.lwt.message,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// Stop listening for new packets.
|
||||
cl.Stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
142
mqtt_test.go
142
mqtt_test.go
@@ -819,7 +819,6 @@ func TestServerProcessSubscribeWriteError(t *testing.T) {
|
||||
|
||||
func TestServerProcessUnsubscribe(t *testing.T) {
|
||||
s, cl, r, w := setupClient()
|
||||
|
||||
s.Clients.Add(cl)
|
||||
s.Topics.Subscribe("a/b/c", cl.ID, 0)
|
||||
s.Topics.Subscribe("d/e/f", cl.ID, 1)
|
||||
@@ -880,133 +879,22 @@ func TestServerProcessUnsubscribeWriteError(t *testing.T) {
|
||||
require.Equal(t, false, close)
|
||||
}
|
||||
|
||||
/*
|
||||
func TestServerClose(t *testing.T) {
|
||||
s, cl, _, _ := setupClient()
|
||||
cl.Listener = "t1"
|
||||
s.Clients.Add(cl)
|
||||
|
||||
|
||||
|
||||
func TestServerProcessSubscribeWriteRetainedError(t *testing.T) {
|
||||
s, _, _, cl := setupClient("zen")
|
||||
cl.p.W = &quietWriter{errAfter: 1}
|
||||
|
||||
s.topics.RetainMessage(&packets.PublishPacket{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Retain: true,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello"),
|
||||
})
|
||||
require.Equal(t, 1, len(s.topics.Messages("a/b/c")))
|
||||
|
||||
err := s.processPacket(cl, &packets.SubscribePacket{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Subscribe,
|
||||
},
|
||||
PacketID: 10,
|
||||
Topics: []string{"a/b/c", "d/e/f"},
|
||||
Qoss: []byte{0, 1},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestServerProcessUnsubscribe(t *testing.T) {
|
||||
s, _, _, cl := setupClient("zen")
|
||||
cl.p.W = new(quietWriter)
|
||||
|
||||
s.clients.add(cl)
|
||||
s.topics.Subscribe("a/b/c", cl.id, 0)
|
||||
s.topics.Subscribe("d/e/f", cl.id, 1)
|
||||
cl.noteSubscription("a/b/c", 0)
|
||||
cl.noteSubscription("d/e/f", 1)
|
||||
|
||||
err := s.processPacket(cl, &packets.UnsubscribePacket{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Unsubscribe,
|
||||
},
|
||||
PacketID: 12,
|
||||
Topics: []string{"a/b/c", "d/e/f"},
|
||||
})
|
||||
err := s.AddListener(listeners.NewMockListener("t1", ":1882"), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Unsuback << 4), 2, // Fixed header
|
||||
0, 12, // Packet ID - LSB+MSB
|
||||
}, cl.p.W.(*quietWriter).f[0])
|
||||
s.Serve()
|
||||
time.Sleep(time.Millisecond)
|
||||
require.Equal(t, 1, s.Listeners.Len())
|
||||
|
||||
require.Empty(t, s.topics.Subscribers("a/b/c"))
|
||||
require.Empty(t, s.topics.Subscribers("d/e/f"))
|
||||
require.NotContains(t, cl.subscriptions, "a/b/c")
|
||||
require.NotContains(t, cl.subscriptions, "d/e/f")
|
||||
listener, ok := s.Listeners.Get("t1")
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, true, listener.(*listeners.MockListener).IsServing)
|
||||
|
||||
s.Close()
|
||||
time.Sleep(time.Millisecond)
|
||||
require.Equal(t, false, listener.(*listeners.MockListener).IsServing)
|
||||
}
|
||||
|
||||
func BenchmarkServerProcessUnsubscribe(b *testing.B) {
|
||||
s, _, _, cl := setupClient("zen")
|
||||
cl.p.W = new(quietWriter)
|
||||
|
||||
pk := &packets.UnsubscribePacket{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Unsubscribe,
|
||||
},
|
||||
PacketID: 12,
|
||||
Topics: []string{"a/b/c"},
|
||||
}
|
||||
for n := 0; n < b.N; n++ {
|
||||
err := s.processUnsubscribe(cl, pk)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerProcessUnsubscribeWriteError(t *testing.T) {
|
||||
s, _, _, cl := setupClient("zen")
|
||||
cl.p.W = &quietWriter{errAfter: -1}
|
||||
err := s.processPacket(cl, &packets.UnsubscribePacket{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Unsubscribe,
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
|
||||
func TestResendInflight(t *testing.T) {
|
||||
s, _, _, cl := setupClient("zen")
|
||||
cl.inFlight.set(1, &inFlightMessage{
|
||||
packet: &packets.PublishPacket{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Qos: 1,
|
||||
Retain: true,
|
||||
Dup: true,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello"),
|
||||
PacketID: 1,
|
||||
},
|
||||
sent: time.Now().Unix(),
|
||||
})
|
||||
|
||||
err := s.resendInflight(cl)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte{
|
||||
byte(packets.Publish<<4 | 11), 14, // Fixed header QoS : 1
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
0, 1, // packet id from qos=1
|
||||
'h', 'e', 'l', 'l', 'o', // Payload)
|
||||
}, cl.p.W.Get()[:16])
|
||||
}
|
||||
|
||||
func TestResendInflightWriteError(t *testing.T) {
|
||||
s, _, _, cl := setupClient("zen")
|
||||
cl.inFlight.set(1, &inFlightMessage{
|
||||
packet: &packets.PublishPacket{},
|
||||
})
|
||||
|
||||
cl.p.W.Close()
|
||||
err := s.resendInflight(cl)
|
||||
require.Error(t, err)
|
||||
}
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user