mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-25 00:50:31 +08:00
bytes buffer to pool
This commit is contained in:
@@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -11,9 +12,14 @@ import (
|
|||||||
|
|
||||||
"github.com/mochi-co/mqtt"
|
"github.com/mochi-co/mqtt"
|
||||||
"github.com/mochi-co/mqtt/internal/listeners"
|
"github.com/mochi-co/mqtt/internal/listeners"
|
||||||
|
|
||||||
|
_ "net/http/pprof"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
go func() {
|
||||||
|
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||||
|
}()
|
||||||
|
|
||||||
sigs := make(chan os.Signal, 1)
|
sigs := make(chan os.Signal, 1)
|
||||||
done := make(chan bool, 1)
|
done := make(chan bool, 1)
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -6,6 +6,7 @@ require (
|
|||||||
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a
|
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a
|
||||||
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23
|
github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23
|
||||||
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164
|
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164
|
||||||
|
github.com/pkg/profile v1.4.0
|
||||||
github.com/rs/xid v1.2.1
|
github.com/rs/xid v1.2.1
|
||||||
github.com/stretchr/testify v1.4.0
|
github.com/stretchr/testify v1.4.0
|
||||||
)
|
)
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -8,6 +8,8 @@ github.com/mochi-co/debug v0.0.0-20191124114744-82bf8b6739b8 h1:BIY2BMCLHm6hE/SU
|
|||||||
github.com/mochi-co/debug v0.0.0-20191124114744-82bf8b6739b8/go.mod h1:AqE7zHPhLOj61seX0vXvzpGiD9Q3Bx5LQPf/FleHKWc=
|
github.com/mochi-co/debug v0.0.0-20191124114744-82bf8b6739b8/go.mod h1:AqE7zHPhLOj61seX0vXvzpGiD9Q3Bx5LQPf/FleHKWc=
|
||||||
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164 h1:XGYo79ZRE9pQE9B5iZCYw3VLaq88PfxcdvDf9crG+dQ=
|
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164 h1:XGYo79ZRE9pQE9B5iZCYw3VLaq88PfxcdvDf9crG+dQ=
|
||||||
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164/go.mod h1:LfBrWXdsMaDKL0ZjcbnLjeYL48Nlo1nW4MltMDYqr44=
|
github.com/mochi-co/debug v0.0.0-20191124131204-24fd1e001164/go.mod h1:LfBrWXdsMaDKL0ZjcbnLjeYL48Nlo1nW4MltMDYqr44=
|
||||||
|
github.com/pkg/profile v1.4.0 h1:uCmaf4vVbWAOZz36k1hrQD7ijGRzLwaME8Am/7a4jZI=
|
||||||
|
github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc=
|
github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc=
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
DefaultBufferSize int = 2048 // the default size of the buffer in bytes.
|
DefaultBufferSize int = 1024 * 256 // the default size of the buffer in bytes.
|
||||||
DefaultBlockSize int = 128 // the default size per R/W block in bytes.
|
DefaultBlockSize int = 1024 * 8 // the default size per R/W block in bytes.
|
||||||
|
|
||||||
ErrOutOfRange = fmt.Errorf("Indexes out of range")
|
ErrOutOfRange = fmt.Errorf("Indexes out of range")
|
||||||
ErrInsufficientBytes = fmt.Errorf("Insufficient bytes to return")
|
ErrInsufficientBytes = fmt.Errorf("Insufficient bytes to return")
|
||||||
@@ -43,6 +43,7 @@ func NewBuffer(size, block int) Buffer {
|
|||||||
if block == 0 {
|
if block == 0 {
|
||||||
block = DefaultBlockSize
|
block = DefaultBlockSize
|
||||||
}
|
}
|
||||||
|
|
||||||
if size < 2*block {
|
if size < 2*block {
|
||||||
size = 2 * block
|
size = 2 * block
|
||||||
}
|
}
|
||||||
@@ -52,12 +53,32 @@ func NewBuffer(size, block int) Buffer {
|
|||||||
mask: size - 1,
|
mask: size - 1,
|
||||||
block: block,
|
block: block,
|
||||||
buf: make([]byte, size),
|
buf: make([]byte, size),
|
||||||
tmp: make([]byte, size),
|
|
||||||
rcond: sync.NewCond(new(sync.Mutex)),
|
rcond: sync.NewCond(new(sync.Mutex)),
|
||||||
wcond: sync.NewCond(new(sync.Mutex)),
|
wcond: sync.NewCond(new(sync.Mutex)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewBufferFromSlice returns a new instance of buffer using a
|
||||||
|
// pre-existing byte slice.
|
||||||
|
func NewBufferFromSlice(block int, buf []byte) Buffer {
|
||||||
|
l := len(buf)
|
||||||
|
|
||||||
|
if block == 0 {
|
||||||
|
block = DefaultBlockSize
|
||||||
|
}
|
||||||
|
|
||||||
|
b := Buffer{
|
||||||
|
size: l,
|
||||||
|
mask: l - 1,
|
||||||
|
block: block,
|
||||||
|
buf: buf,
|
||||||
|
rcond: sync.NewCond(new(sync.Mutex)),
|
||||||
|
wcond: sync.NewCond(new(sync.Mutex)),
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// Get will return the tail and head positions of the buffer.
|
// Get will return the tail and head positions of the buffer.
|
||||||
// This method is for use with testing.
|
// This method is for use with testing.
|
||||||
func (b *Buffer) GetPos() (int64, int64) {
|
func (b *Buffer) GetPos() (int64, int64) {
|
||||||
|
|||||||
@@ -36,6 +36,20 @@ func TestNewBufferUndersize(t *testing.T) {
|
|||||||
require.Equal(t, DefaultBlockSize, buf.block)
|
require.Equal(t, DefaultBlockSize, buf.block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewBufferFromSlice(t *testing.T) {
|
||||||
|
b := NewBytesPool(256)
|
||||||
|
buf := NewBufferFromSlice(DefaultBlockSize, b.Get())
|
||||||
|
require.NotNil(t, buf.buf)
|
||||||
|
require.Equal(t, 256, cap(buf.buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewBufferFromSlice0Size(t *testing.T) {
|
||||||
|
b := NewBytesPool(256)
|
||||||
|
buf := NewBufferFromSlice(0, b.Get())
|
||||||
|
require.NotNil(t, buf.buf)
|
||||||
|
require.Equal(t, 256, cap(buf.buf))
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetPos(t *testing.T) {
|
func TestGetPos(t *testing.T) {
|
||||||
buf := NewBuffer(16, 4)
|
buf := NewBuffer(16, 4)
|
||||||
tail, head := buf.GetPos()
|
tail, head := buf.GetPos()
|
||||||
|
|||||||
32
internal/circ/pool.go
Normal file
32
internal/circ/pool.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package circ
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BytesPool is a pool of []byte
|
||||||
|
type BytesPool struct {
|
||||||
|
pool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBytesPool returns a sync.pool of []byte
|
||||||
|
func NewBytesPool(n int) BytesPool {
|
||||||
|
return BytesPool{
|
||||||
|
pool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, n)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a pooled bytes.Buffer.
|
||||||
|
func (b BytesPool) Get() []byte {
|
||||||
|
return b.pool.Get().([]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put puts the byte slice back into the pool.
|
||||||
|
func (b BytesPool) Put(x []byte) {
|
||||||
|
x = x[:0]
|
||||||
|
b.pool.Put(x)
|
||||||
|
}
|
||||||
46
internal/circ/pool_test.go
Normal file
46
internal/circ/pool_test.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package circ
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewBytesPool(t *testing.T) {
|
||||||
|
bpool := NewBytesPool(256)
|
||||||
|
require.NotNil(t, bpool.pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkNewBytesPool(b *testing.B) {
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
NewBytesPool(256)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewBytesPoolGet(t *testing.T) {
|
||||||
|
bpool := NewBytesPool(256)
|
||||||
|
buf := bpool.Get()
|
||||||
|
|
||||||
|
require.Equal(t, make([]byte, 256), buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBytesPoolGet(b *testing.B) {
|
||||||
|
bpool := NewBytesPool(256)
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
bpool.Get()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewBytesPoolPut(t *testing.T) {
|
||||||
|
bpool := NewBytesPool(256)
|
||||||
|
buf := bpool.Get()
|
||||||
|
bpool.Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBytesPoolPut(b *testing.B) {
|
||||||
|
bpool := NewBytesPool(256)
|
||||||
|
buf := bpool.Get()
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
bpool.Put(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,8 +3,6 @@ package circ
|
|||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
dbg "github.com/mochi-co/debug"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reader is a circular buffer for reading data from an io.Reader.
|
// Reader is a circular buffer for reading data from an io.Reader.
|
||||||
@@ -12,7 +10,7 @@ type Reader struct {
|
|||||||
Buffer
|
Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReader returns a pointer to a new Circular Reader.
|
// NewReader returns a new Circular Reader.
|
||||||
func NewReader(size, block int) *Reader {
|
func NewReader(size, block int) *Reader {
|
||||||
b := NewBuffer(size, block)
|
b := NewBuffer(size, block)
|
||||||
b.ID = "\treader"
|
b.ID = "\treader"
|
||||||
@@ -21,6 +19,16 @@ func NewReader(size, block int) *Reader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewReaderFromSlice returns a new Circular Reader using a pre-exising
|
||||||
|
// byte slice.
|
||||||
|
func NewReaderFromSlice(block int, p []byte) *Reader {
|
||||||
|
b := NewBufferFromSlice(block, p)
|
||||||
|
b.ID = "\treader"
|
||||||
|
return &Reader{
|
||||||
|
b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
||||||
// there is sufficient capacity to do so.
|
// there is sufficient capacity to do so.
|
||||||
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||||
@@ -48,8 +56,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
|||||||
end = b.size
|
end = b.size
|
||||||
}
|
}
|
||||||
|
|
||||||
dbg.Println(dbg.Yellow, b.ID, "b.ReadFrom allocating", start, ":", end)
|
|
||||||
|
|
||||||
// Read into the buffer between the start and end indexes only.
|
// Read into the buffer between the start and end indexes only.
|
||||||
n, err := r.Read(b.buf[start:end])
|
n, err := r.Read(b.buf[start:end])
|
||||||
total += int64(n) // incr total bytes read.
|
total += int64(n) // incr total bytes read.
|
||||||
@@ -57,8 +63,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
|||||||
return total, nil
|
return total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
dbg.Println(dbg.HiYellow, b.ID, "b.ReadFrom received", n, b.buf[start:start+n])
|
|
||||||
|
|
||||||
// Move the head forward however many bytes were read.
|
// Move the head forward however many bytes were read.
|
||||||
atomic.AddInt64(&b.head, int64(n))
|
atomic.AddInt64(&b.head, int64(n))
|
||||||
|
|
||||||
@@ -71,8 +75,6 @@ func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
|||||||
// Read reads n bytes from the buffer, and will block until at n bytes
|
// Read reads n bytes from the buffer, and will block until at n bytes
|
||||||
// exist in the buffer to read.
|
// exist in the buffer to read.
|
||||||
func (b *Buffer) Read(n int) (p []byte, err error) {
|
func (b *Buffer) Read(n int) (p []byte, err error) {
|
||||||
dbg.Println(dbg.Cyan, b.ID, "b.Read waiting for", n, "bytes")
|
|
||||||
|
|
||||||
err = b.awaitFilled(n)
|
err = b.awaitFilled(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -90,7 +92,5 @@ func (b *Buffer) Read(n int) (p []byte, err error) {
|
|||||||
b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read.
|
b.tmp = b.buf[b.Index(tail):b.Index(next)] // Otherwise, simple tail:next read.
|
||||||
}
|
}
|
||||||
|
|
||||||
dbg.Println(dbg.HiCyan, b.ID, "b.Read read", tail, next, b.tmp)
|
|
||||||
|
|
||||||
return b.tmp, nil
|
return b.tmp, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,13 @@ func TestNewReader(t *testing.T) {
|
|||||||
require.Equal(t, block, buf.block)
|
require.Equal(t, block, buf.block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewReaderFromSlice(t *testing.T) {
|
||||||
|
b := NewBytesPool(256)
|
||||||
|
buf := NewReaderFromSlice(DefaultBlockSize, b.Get())
|
||||||
|
require.NotNil(t, buf.buf)
|
||||||
|
require.Equal(t, 256, cap(buf.buf))
|
||||||
|
}
|
||||||
|
|
||||||
func TestReadFrom(t *testing.T) {
|
func TestReadFrom(t *testing.T) {
|
||||||
buf := NewReader(16, 4)
|
buf := NewReader(16, 4)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,16 @@ func NewWriter(size, block int) *Writer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewWriterFromSlice returns a new Circular Writer using a pre-exising
|
||||||
|
// byte slice.
|
||||||
|
func NewWriterFromSlice(block int, p []byte) *Writer {
|
||||||
|
b := NewBufferFromSlice(block, p)
|
||||||
|
b.ID = "writer"
|
||||||
|
return &Writer{
|
||||||
|
b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WriteTo writes the contents of the buffer to an io.Writer.
|
// WriteTo writes the contents of the buffer to an io.Writer.
|
||||||
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
|
func (b *Writer) WriteTo(w io.Writer) (total int, err error) {
|
||||||
atomic.StoreInt64(&b.State, 2)
|
atomic.StoreInt64(&b.State, 2)
|
||||||
|
|||||||
@@ -22,6 +22,13 @@ func TestNewWriter(t *testing.T) {
|
|||||||
require.Equal(t, block, buf.block)
|
require.Equal(t, block, buf.block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewWriterFromSlice(t *testing.T) {
|
||||||
|
b := NewBytesPool(256)
|
||||||
|
buf := NewWriterFromSlice(DefaultBlockSize, b.Get())
|
||||||
|
require.NotNil(t, buf.buf)
|
||||||
|
require.Equal(t, 256, cap(buf.buf))
|
||||||
|
}
|
||||||
|
|
||||||
func TestWriteTo(t *testing.T) {
|
func TestWriteTo(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
tail int64
|
tail int64
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
dbg "github.com/mochi-co/debug"
|
|
||||||
"github.com/mochi-co/mqtt/internal/auth"
|
"github.com/mochi-co/mqtt/internal/auth"
|
||||||
"github.com/mochi-co/mqtt/internal/circ"
|
"github.com/mochi-co/mqtt/internal/circ"
|
||||||
"github.com/mochi-co/mqtt/internal/packets"
|
"github.com/mochi-co/mqtt/internal/packets"
|
||||||
@@ -210,27 +209,26 @@ func (cl *Client) Start() {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
cl.state.started.Done()
|
cl.state.started.Done()
|
||||||
_, err := cl.w.WriteTo(cl.conn)
|
//_, err :=
|
||||||
dbg.Println(dbg.HiRed, cl.ID, "WriteTo stopped", err)
|
cl.w.WriteTo(cl.conn)
|
||||||
cl.state.endedW.Done()
|
cl.state.endedW.Done()
|
||||||
//cl.close()
|
//cl.close()
|
||||||
}()
|
}()
|
||||||
cl.state.endedW.Add(1)
|
cl.state.endedW.Add(1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
cl.state.started.Done()
|
cl.state.started.Done()
|
||||||
_, err := cl.r.ReadFrom(cl.conn)
|
//_, err :=
|
||||||
dbg.Println(dbg.HiRed, cl.ID, "ReadFrom stopped", err)
|
cl.r.ReadFrom(cl.conn)
|
||||||
cl.state.endedR.Done()
|
cl.state.endedR.Done()
|
||||||
//cl.close()
|
//cl.close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cl.state.endedR.Add(1)
|
cl.state.endedR.Add(1)
|
||||||
cl.state.started.Wait()
|
cl.state.started.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||||
func (cl *Client) Stop() {
|
func (cl *Client) Stop() {
|
||||||
dbg.Println(dbg.HiRed+"CLIENT stop called...", dbg.Underline+cl.ID)
|
|
||||||
cl.r.Stop()
|
cl.r.Stop()
|
||||||
cl.w.Stop()
|
cl.w.Stop()
|
||||||
cl.state.endedW.Wait()
|
cl.state.endedW.Wait()
|
||||||
@@ -241,7 +239,6 @@ func (cl *Client) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cl.state.endedR.Wait()
|
cl.state.endedR.Wait()
|
||||||
dbg.Println(dbg.HiRed+"CLIENT stopped", dbg.Underline+cl.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// readFixedHeader reads in the values of the next packet's fixed header.
|
// readFixedHeader reads in the values of the next packet's fixed header.
|
||||||
|
|||||||
135
mqtt.go
135
mqtt.go
@@ -23,28 +23,11 @@ var (
|
|||||||
ErrListenerIDExists = errors.New("Listener id already exists")
|
ErrListenerIDExists = errors.New("Listener id already exists")
|
||||||
ErrReadConnectInvalid = errors.New("Connect packet was not valid")
|
ErrReadConnectInvalid = errors.New("Connect packet was not valid")
|
||||||
ErrConnectNotAuthorized = errors.New("Connect packet was not authorized")
|
ErrConnectNotAuthorized = errors.New("Connect packet was not authorized")
|
||||||
|
|
||||||
// ErrACLNotAuthorized = errors.New("ACL not authorized")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
|
||||||
var (
|
|
||||||
ErrListenerIDExists = errors.New("Listener id already exists")
|
|
||||||
ErrReadConnectFixedHeader = errors.New("Error reading fixed header on CONNECT packet")
|
|
||||||
ErrReadConnectPacket = errors.New("Error reading CONNECT packet")
|
|
||||||
ErrReadConnectInvalid = errors.New("CONNECT packet was not valid")
|
|
||||||
|
|
||||||
ErrReadFixedHeader = errors.New("Error reading fixed header")
|
|
||||||
ErrReadPacketPayload = errors.New("Error reading packet payload")
|
|
||||||
ErrReadPacketValidation = errors.New("Error validating packet")
|
|
||||||
ErrConnectionClosed = errors.New("Connection not open")
|
|
||||||
ErrNoData = errors.New("No data")
|
|
||||||
ErrACLNotAuthorized = errors.New("ACL not authorized")
|
|
||||||
)
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Server is an MQTT broker server.
|
// Server is an MQTT broker server.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
|
bytepool circ.BytesPool
|
||||||
Listeners listeners.Listeners // listeners listen for new connections.
|
Listeners listeners.Listeners // listeners listen for new connections.
|
||||||
Clients clients.Clients // clients known to the broker.
|
Clients clients.Clients // clients known to the broker.
|
||||||
Topics *topics.Index // an index of topic subscriptions and retained messages.
|
Topics *topics.Index // an index of topic subscriptions and retained messages.
|
||||||
@@ -52,8 +35,8 @@ type Server struct {
|
|||||||
|
|
||||||
// New returns a new instance of an MQTT broker.
|
// New returns a new instance of an MQTT broker.
|
||||||
func New() *Server {
|
func New() *Server {
|
||||||
fmt.Println()
|
|
||||||
return &Server{
|
return &Server{
|
||||||
|
bytepool: circ.NewBytesPool(circ.DefaultBufferSize),
|
||||||
Listeners: listeners.New(),
|
Listeners: listeners.New(),
|
||||||
Clients: clients.New(),
|
Clients: clients.New(),
|
||||||
Topics: topics.New(),
|
Topics: topics.New(),
|
||||||
@@ -82,53 +65,14 @@ func (s *Server) Serve() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close attempts to gracefully shutdown the server, all listeners, and clients.
|
|
||||||
func (s *Server) Close() error {
|
|
||||||
s.Listeners.CloseAll(s.closeListenerClients)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeListenerClients closes all clients on the specified listener.
|
|
||||||
func (s *Server) closeListenerClients(listener string) {
|
|
||||||
clients := s.Clients.GetByListener(listener)
|
|
||||||
for _, client := range clients {
|
|
||||||
s.closeClient(client, false) // omit errors
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeClient closes a client connection and publishes any LWT messages.
|
|
||||||
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
|
|
||||||
|
|
||||||
//debug.Println(cl.ID, "SERVER STOPS ISSUED >> ")
|
|
||||||
|
|
||||||
// If an LWT message is set, publish it to the topic subscribers.
|
|
||||||
/* // this currently loops forever on broken connection
|
|
||||||
if sendLWT && cl.lwt.topic != "" {
|
|
||||||
err := s.processPublish(cl, &packets.PublishPacket{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
|
||||||
Type: packets.Publish,
|
|
||||||
Retain: cl.lwt.retain,
|
|
||||||
Qos: cl.lwt.qos,
|
|
||||||
},
|
|
||||||
TopicName: cl.lwt.topic,
|
|
||||||
Payload: cl.lwt.message,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Stop listening for new packets.
|
|
||||||
cl.Stop()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// EstablishConnection establishes a new client connection with the broker.
|
// EstablishConnection establishes a new client connection with the broker.
|
||||||
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
|
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
|
||||||
client := clients.NewClient(c, circ.NewReader(0, 0), circ.NewWriter(0, 0))
|
//client := clients.NewClient(c, circ.NewReader(0, 0), circ.NewWriter(0, 0))
|
||||||
|
client := clients.NewClient(c,
|
||||||
|
circ.NewReaderFromSlice(0, s.bytepool.Get()),
|
||||||
|
circ.NewWriterFromSlice(0, s.bytepool.Get()),
|
||||||
|
)
|
||||||
|
|
||||||
client.Start()
|
client.Start()
|
||||||
|
|
||||||
fh := new(packets.FixedHeader)
|
fh := new(packets.FixedHeader)
|
||||||
@@ -151,6 +95,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
|||||||
if !ac.Authenticate(pk.Username, pk.Password) {
|
if !ac.Authenticate(pk.Username, pk.Password) {
|
||||||
retcode = packets.CodeConnectBadAuthValues
|
retcode = packets.CodeConnectBadAuthValues
|
||||||
}
|
}
|
||||||
|
|
||||||
var sessionPresent bool
|
var sessionPresent bool
|
||||||
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
|
if existing, ok := s.Clients.Get(pk.ClientIdentifier); ok {
|
||||||
existing.Lock()
|
existing.Lock()
|
||||||
@@ -168,10 +113,8 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
|||||||
existing.Unlock()
|
existing.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the new client to the clients manager.
|
|
||||||
s.Clients.Add(client)
|
s.Clients.Add(client)
|
||||||
|
|
||||||
// Send a CONNACK back to the client with retcode.
|
|
||||||
err = s.writeClient(client, packets.Packet{
|
err = s.writeClient(client, packets.Packet{
|
||||||
FixedHeader: packets.FixedHeader{
|
FixedHeader: packets.FixedHeader{
|
||||||
Type: packets.Connack,
|
Type: packets.Connack,
|
||||||
@@ -184,21 +127,17 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resend any unacknowledged QOS messages still pending for the client.
|
err = s.ResendInflight(client)
|
||||||
/*err = s.resendInflight(client)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
// Block and listen for more packets, and end if an error or nil packet occurs.
|
|
||||||
var sendLWT bool
|
var sendLWT bool
|
||||||
err = client.Read(s.processPacket)
|
err = client.Read(s.processPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3]
|
sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish last will and testament then close.
|
|
||||||
s.closeClient(client, sendLWT)
|
s.closeClient(client, sendLWT)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -217,13 +156,10 @@ func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// resendInflight republishes any inflight messages to the client.
|
// ResendInflight republishes any inflight messages to the client.
|
||||||
/*func (s *Server) resendInflight(cl *clients.Client) error {
|
func (s *Server) ResendInflight(cl *clients.Client) error {
|
||||||
cl.RLock()
|
for _, pk := range cl.InFlight.GetAll() {
|
||||||
msgs := cl.inFlight.internal
|
err := s.writeClient(cl, pk.Packet)
|
||||||
cl.RUnlock()
|
|
||||||
for _, msg := range msgs {
|
|
||||||
err := s.writeClient(cl, msg.packet)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -231,7 +167,6 @@ func (s *Server) writeClient(cl *clients.Client, pk packets.Packet) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
// processPacket processes an inbound packet for a client. Since the method is
|
// processPacket processes an inbound packet for a client. Since the method is
|
||||||
// typically called as a goroutine, errors are mostly for test checking purposes.
|
// typically called as a goroutine, errors are mostly for test checking purposes.
|
||||||
@@ -453,3 +388,45 @@ func (s *Server) processUnsubscribe(cl *clients.Client, pk packets.Packet) (clos
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close attempts to gracefully shutdown the server, all listeners, and clients.
|
||||||
|
func (s *Server) Close() error {
|
||||||
|
s.Listeners.CloseAll(s.closeListenerClients)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeListenerClients closes all clients on the specified listener.
|
||||||
|
func (s *Server) closeListenerClients(listener string) {
|
||||||
|
clients := s.Clients.GetByListener(listener)
|
||||||
|
for _, client := range clients {
|
||||||
|
s.closeClient(client, false) // omit errors
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeClient closes a client connection and publishes any LWT messages.
|
||||||
|
func (s *Server) closeClient(cl *clients.Client, sendLWT bool) error {
|
||||||
|
// If an LWT message is set, publish it to the topic subscribers.
|
||||||
|
|
||||||
|
/*
|
||||||
|
if sendLWT && cl.lwt.topic != "" {
|
||||||
|
err := s.processPublish(cl, &packets.PublishPacket{
|
||||||
|
FixedHeader: packets.FixedHeader{
|
||||||
|
Type: packets.Publish,
|
||||||
|
Retain: cl.lwt.retain,
|
||||||
|
Qos: cl.lwt.qos,
|
||||||
|
},
|
||||||
|
TopicName: cl.lwt.topic,
|
||||||
|
Payload: cl.lwt.message,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Stop listening for new packets.
|
||||||
|
cl.Stop()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
142
mqtt_test.go
142
mqtt_test.go
@@ -819,7 +819,6 @@ func TestServerProcessSubscribeWriteError(t *testing.T) {
|
|||||||
|
|
||||||
func TestServerProcessUnsubscribe(t *testing.T) {
|
func TestServerProcessUnsubscribe(t *testing.T) {
|
||||||
s, cl, r, w := setupClient()
|
s, cl, r, w := setupClient()
|
||||||
|
|
||||||
s.Clients.Add(cl)
|
s.Clients.Add(cl)
|
||||||
s.Topics.Subscribe("a/b/c", cl.ID, 0)
|
s.Topics.Subscribe("a/b/c", cl.ID, 0)
|
||||||
s.Topics.Subscribe("d/e/f", cl.ID, 1)
|
s.Topics.Subscribe("d/e/f", cl.ID, 1)
|
||||||
@@ -880,133 +879,22 @@ func TestServerProcessUnsubscribeWriteError(t *testing.T) {
|
|||||||
require.Equal(t, false, close)
|
require.Equal(t, false, close)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
func TestServerClose(t *testing.T) {
|
||||||
|
s, cl, _, _ := setupClient()
|
||||||
|
cl.Listener = "t1"
|
||||||
|
s.Clients.Add(cl)
|
||||||
|
|
||||||
|
err := s.AddListener(listeners.NewMockListener("t1", ":1882"), nil)
|
||||||
|
|
||||||
func TestServerProcessSubscribeWriteRetainedError(t *testing.T) {
|
|
||||||
s, _, _, cl := setupClient("zen")
|
|
||||||
cl.p.W = &quietWriter{errAfter: 1}
|
|
||||||
|
|
||||||
s.topics.RetainMessage(&packets.PublishPacket{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
|
||||||
Type: packets.Publish,
|
|
||||||
Retain: true,
|
|
||||||
},
|
|
||||||
TopicName: "a/b/c",
|
|
||||||
Payload: []byte("hello"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 1, len(s.topics.Messages("a/b/c")))
|
|
||||||
|
|
||||||
err := s.processPacket(cl, &packets.SubscribePacket{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
|
||||||
Type: packets.Subscribe,
|
|
||||||
},
|
|
||||||
PacketID: 10,
|
|
||||||
Topics: []string{"a/b/c", "d/e/f"},
|
|
||||||
Qoss: []byte{0, 1},
|
|
||||||
})
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerProcessUnsubscribe(t *testing.T) {
|
|
||||||
s, _, _, cl := setupClient("zen")
|
|
||||||
cl.p.W = new(quietWriter)
|
|
||||||
|
|
||||||
s.clients.add(cl)
|
|
||||||
s.topics.Subscribe("a/b/c", cl.id, 0)
|
|
||||||
s.topics.Subscribe("d/e/f", cl.id, 1)
|
|
||||||
cl.noteSubscription("a/b/c", 0)
|
|
||||||
cl.noteSubscription("d/e/f", 1)
|
|
||||||
|
|
||||||
err := s.processPacket(cl, &packets.UnsubscribePacket{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
|
||||||
Type: packets.Unsubscribe,
|
|
||||||
},
|
|
||||||
PacketID: 12,
|
|
||||||
Topics: []string{"a/b/c", "d/e/f"},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []byte{
|
s.Serve()
|
||||||
byte(packets.Unsuback << 4), 2, // Fixed header
|
time.Sleep(time.Millisecond)
|
||||||
0, 12, // Packet ID - LSB+MSB
|
require.Equal(t, 1, s.Listeners.Len())
|
||||||
}, cl.p.W.(*quietWriter).f[0])
|
|
||||||
|
|
||||||
require.Empty(t, s.topics.Subscribers("a/b/c"))
|
listener, ok := s.Listeners.Get("t1")
|
||||||
require.Empty(t, s.topics.Subscribers("d/e/f"))
|
require.Equal(t, true, ok)
|
||||||
require.NotContains(t, cl.subscriptions, "a/b/c")
|
require.Equal(t, true, listener.(*listeners.MockListener).IsServing)
|
||||||
require.NotContains(t, cl.subscriptions, "d/e/f")
|
|
||||||
|
s.Close()
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
require.Equal(t, false, listener.(*listeners.MockListener).IsServing)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkServerProcessUnsubscribe(b *testing.B) {
|
|
||||||
s, _, _, cl := setupClient("zen")
|
|
||||||
cl.p.W = new(quietWriter)
|
|
||||||
|
|
||||||
pk := &packets.UnsubscribePacket{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
|
||||||
Type: packets.Unsubscribe,
|
|
||||||
},
|
|
||||||
PacketID: 12,
|
|
||||||
Topics: []string{"a/b/c"},
|
|
||||||
}
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
err := s.processUnsubscribe(cl, pk)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerProcessUnsubscribeWriteError(t *testing.T) {
|
|
||||||
s, _, _, cl := setupClient("zen")
|
|
||||||
cl.p.W = &quietWriter{errAfter: -1}
|
|
||||||
err := s.processPacket(cl, &packets.UnsubscribePacket{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
|
||||||
Type: packets.Unsubscribe,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*
|
|
||||||
|
|
||||||
func TestResendInflight(t *testing.T) {
|
|
||||||
s, _, _, cl := setupClient("zen")
|
|
||||||
cl.inFlight.set(1, &inFlightMessage{
|
|
||||||
packet: &packets.PublishPacket{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
|
||||||
Type: packets.Publish,
|
|
||||||
Qos: 1,
|
|
||||||
Retain: true,
|
|
||||||
Dup: true,
|
|
||||||
},
|
|
||||||
TopicName: "a/b/c",
|
|
||||||
Payload: []byte("hello"),
|
|
||||||
PacketID: 1,
|
|
||||||
},
|
|
||||||
sent: time.Now().Unix(),
|
|
||||||
})
|
|
||||||
|
|
||||||
err := s.resendInflight(cl)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, []byte{
|
|
||||||
byte(packets.Publish<<4 | 11), 14, // Fixed header QoS : 1
|
|
||||||
0, 5, // Topic Name - LSB+MSB
|
|
||||||
'a', '/', 'b', '/', 'c', // Topic Name
|
|
||||||
0, 1, // packet id from qos=1
|
|
||||||
'h', 'e', 'l', 'l', 'o', // Payload)
|
|
||||||
}, cl.p.W.Get()[:16])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResendInflightWriteError(t *testing.T) {
|
|
||||||
s, _, _, cl := setupClient("zen")
|
|
||||||
cl.inFlight.set(1, &inFlightMessage{
|
|
||||||
packet: &packets.PublishPacket{},
|
|
||||||
})
|
|
||||||
|
|
||||||
cl.p.W.Close()
|
|
||||||
err := s.resendInflight(cl)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|||||||
Reference in New Issue
Block a user