mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-10-04 07:46:34 +08:00
Debug func, better closers
This commit is contained in:
@@ -1,25 +1,28 @@
|
|||||||
package circ
|
package circ
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/mochi-co/mqtt/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
||||||
// blockSize is the default size of bytes per R/W block.
|
|
||||||
blockSize int64 = 8192
|
|
||||||
|
|
||||||
// bufferSize is the default size of the buffer.
|
// bufferSize is the default size of the buffer.
|
||||||
bufferSize int64 = 1024 * 256
|
bufferSize int64 = 1024 * 256
|
||||||
|
|
||||||
|
// blockSize is the default size of bytes per R/W block.
|
||||||
|
blockSize int64 = 2048
|
||||||
)
|
)
|
||||||
|
|
||||||
// buffer contains core values and methods to be included in a reader or writer.
|
// buffer contains core values and methods to be included in a reader or writer.
|
||||||
type buffer struct {
|
type buffer struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
|
// id is the identifier of the buffer. This is used in debug output.
|
||||||
|
id string
|
||||||
|
|
||||||
// size is the size of the buffer.
|
// size is the size of the buffer.
|
||||||
size int64
|
size int64
|
||||||
|
|
||||||
@@ -73,11 +76,32 @@ func newBuffer(size, block int64) buffer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close notifies the buffer event loops to stop.
|
||||||
|
func (b *buffer) Close() {
|
||||||
|
atomic.StoreInt64(&b.done, 1)
|
||||||
|
debug.Println(b.id, "[B] STORE done=1 buffer closing")
|
||||||
|
|
||||||
|
b.wcond.L.Lock()
|
||||||
|
b.wcond.Broadcast()
|
||||||
|
b.wcond.L.Unlock()
|
||||||
|
b.rcond.L.Lock()
|
||||||
|
b.rcond.Broadcast()
|
||||||
|
b.rcond.L.Unlock()
|
||||||
|
|
||||||
|
debug.Println(b.id, "[B] DONE REBROADCASTED")
|
||||||
|
}
|
||||||
|
|
||||||
// Get will return the tail and head positions of the buffer.
|
// Get will return the tail and head positions of the buffer.
|
||||||
func (b *buffer) GetPos() (int64, int64) {
|
func (b *buffer) GetPos() (int64, int64) {
|
||||||
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
|
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPos sets the head and tail of the buffer. This method should only be
|
||||||
|
// used for testing.
|
||||||
|
func (b *buffer) SetPos(tail, head int64) {
|
||||||
|
b.tail, b.head = tail, head
|
||||||
|
}
|
||||||
|
|
||||||
// Set writes bytes to a byte buffer. This method should only be used for testing
|
// Set writes bytes to a byte buffer. This method should only be used for testing
|
||||||
// and will panic if out of range.
|
// and will panic if out of range.
|
||||||
func (b *buffer) Set(p []byte, start, end int) {
|
func (b *buffer) Set(p []byte, start, end int) {
|
||||||
@@ -88,12 +112,6 @@ func (b *buffer) Set(p []byte, start, end int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPos sets the head and tail of the buffer. This method should only be
|
|
||||||
// used for testing.
|
|
||||||
func (b *buffer) SetPos(tail, head int64) {
|
|
||||||
b.tail, b.head = tail, head
|
|
||||||
}
|
|
||||||
|
|
||||||
// awaitCapacity will hold until there are n bytes free in the buffer, blocking
|
// awaitCapacity will hold until there are n bytes free in the buffer, blocking
|
||||||
// until there are at least n bytes to write without overrunning the tail.
|
// until there are at least n bytes to write without overrunning the tail.
|
||||||
func (b *buffer) awaitCapacity(n int64) (head int64, err error) {
|
func (b *buffer) awaitCapacity(n int64) (head int64, err error) {
|
||||||
@@ -102,16 +120,21 @@ func (b *buffer) awaitCapacity(n int64) (head int64, err error) {
|
|||||||
wrapped := next - b.size
|
wrapped := next - b.size
|
||||||
tail := atomic.LoadInt64(&b.tail)
|
tail := atomic.LoadInt64(&b.tail)
|
||||||
|
|
||||||
|
debug.Println(b.id, "[B] awaiting capacity (n)", n)
|
||||||
b.rcond.L.Lock()
|
b.rcond.L.Lock()
|
||||||
for ; wrapped > tail || (tail > head && next > tail && wrapped < 0); tail = atomic.LoadInt64(&b.tail) {
|
for ; wrapped > tail || (tail > head && next > tail && wrapped < 0); tail = atomic.LoadInt64(&b.tail) {
|
||||||
|
debug.Println(b.id, "[B] iter no capacity")
|
||||||
|
|
||||||
//fmt.Println("\t", wrapped, ">", tail, wrapped > tail, "||", tail, ">", head, "&&", next, ">", tail, "&&", wrapped, "<", 0, (tail > head && next > tail && wrapped < 0))
|
//fmt.Println("\t", wrapped, ">", tail, wrapped > tail, "||", tail, ">", head, "&&", next, ">", tail, "&&", wrapped, "<", 0, (tail > head && next > tail && wrapped < 0))
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
if atomic.LoadInt64(&b.done) == 1 {
|
||||||
fmt.Println("ending")
|
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug.Println(b.id, "[B] iter no capacity waiting")
|
||||||
b.rcond.Wait()
|
b.rcond.Wait()
|
||||||
}
|
}
|
||||||
b.rcond.L.Unlock()
|
b.rcond.L.Unlock()
|
||||||
|
debug.Println(b.id, "[B] capacity unlocked (tail)", tail)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -122,15 +145,20 @@ func (b *buffer) awaitFilled(n int64) (tail int64, err error) {
|
|||||||
head := atomic.LoadInt64(&b.head)
|
head := atomic.LoadInt64(&b.head)
|
||||||
tail = atomic.LoadInt64(&b.tail)
|
tail = atomic.LoadInt64(&b.tail)
|
||||||
|
|
||||||
|
debug.Println(b.id, "[B] awaiting filled (tail, head)", tail, head)
|
||||||
b.wcond.L.Lock()
|
b.wcond.L.Lock()
|
||||||
for ; head > tail && tail+n > head || head < tail && b.size-tail+head < n; head = atomic.LoadInt64(&b.head) {
|
for ; head > tail && tail+n > head || head < tail && b.size-tail+head < n; head = atomic.LoadInt64(&b.head) {
|
||||||
|
debug.Println(b.id, "[B] iter no fill")
|
||||||
|
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
if atomic.LoadInt64(&b.done) == 1 {
|
||||||
fmt.Println("ending")
|
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug.Println(b.id, "[B] iter no fill waiting")
|
||||||
b.wcond.Wait()
|
b.wcond.Wait()
|
||||||
}
|
}
|
||||||
b.wcond.L.Unlock()
|
b.wcond.L.Unlock()
|
||||||
|
debug.Println(b.id, "[B] filled (head)", head)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -144,16 +172,20 @@ func (b *buffer) awaitFilled(n int64) (tail int64, err error) {
|
|||||||
// CommitTail moves the tail position of the buffer n bytes, and will wait until
|
// CommitTail moves the tail position of the buffer n bytes, and will wait until
|
||||||
// there is enough capacity for at least n bytes.
|
// there is enough capacity for at least n bytes.
|
||||||
func (b *buffer) CommitTail(n int64) error {
|
func (b *buffer) CommitTail(n int64) error {
|
||||||
|
debug.Println(b.id, "[B] commit/discard (n)", n)
|
||||||
_, err := b.awaitFilled(n)
|
_, err := b.awaitFilled(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
debug.Println(b.id, "[B] commit/discard unlocked")
|
||||||
|
|
||||||
tail := atomic.LoadInt64(&b.tail)
|
tail := atomic.LoadInt64(&b.tail)
|
||||||
if tail+n < b.size {
|
if tail+n < b.size {
|
||||||
atomic.StoreInt64(&b.tail, tail+n)
|
atomic.StoreInt64(&b.tail, tail+n)
|
||||||
|
debug.Println(b.id, "[B] STORE TAIL", tail+n)
|
||||||
} else {
|
} else {
|
||||||
atomic.StoreInt64(&b.tail, (tail+n)%b.size)
|
atomic.StoreInt64(&b.tail, (tail+n)%b.size)
|
||||||
|
debug.Println(b.id, "[B] STORE TAIL (wrapped)", (tail+n)%b.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.rcond.L.Lock()
|
b.rcond.L.Lock()
|
||||||
@@ -162,3 +194,14 @@ func (b *buffer) CommitTail(n int64) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// capDelta returns the difference between the head and tail.
|
||||||
|
func (b *buffer) capDelta(tail, head int64) int64 {
|
||||||
|
if head < tail {
|
||||||
|
debug.Println(b.id, "[B] capdelta (tail, head, v)", tail, head, b.size-tail+head)
|
||||||
|
return b.size - tail + head
|
||||||
|
}
|
||||||
|
|
||||||
|
debug.Println(b.id, "[B] capdelta (tail, head, v)", tail, head, head-tail)
|
||||||
|
return head - tail
|
||||||
|
}
|
@@ -5,6 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/mochi-co/mqtt/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -18,25 +20,31 @@ type Reader struct {
|
|||||||
|
|
||||||
// NewReader returns a pointer to a new Circular Reader.
|
// NewReader returns a pointer to a new Circular Reader.
|
||||||
func NewReader(size, block int64) *Reader {
|
func NewReader(size, block int64) *Reader {
|
||||||
|
b := newBuffer(size, block)
|
||||||
|
b.id = "\treader"
|
||||||
return &Reader{
|
return &Reader{
|
||||||
newBuffer(size, block),
|
b,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
||||||
// there is sufficient capacity to do so.
|
// there is sufficient capacity to do so.
|
||||||
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
|
||||||
|
debug.Println(b.id, "*[R] STARTING SPIN")
|
||||||
DONE:
|
DONE:
|
||||||
for {
|
for {
|
||||||
|
debug.Println(b.id, "*[R] SPIN")
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
if atomic.LoadInt64(&b.done) == 1 {
|
||||||
err = io.EOF
|
err = io.EOF
|
||||||
|
debug.Println(b.id, "*[R] readFrom DONE")
|
||||||
break DONE
|
break DONE
|
||||||
}
|
}
|
||||||
fmt.Println("readingfrom")
|
|
||||||
// Wait until there's enough capacity in the buffer.
|
// Wait until there's enough capacity in the buffer.
|
||||||
st, err := b.awaitCapacity(b.block)
|
st, err := b.awaitCapacity(b.block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break DONE
|
debug.Println(b.id, "*[R] readFrom await capacity error, ", err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the start and end indexes within the buffer.
|
// Find the start and end indexes within the buffer.
|
||||||
@@ -50,21 +58,23 @@ DONE:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read into the buffer between the start and end indexes only.
|
// Read into the buffer between the start and end indexes only.
|
||||||
|
debug.Println(b.id, "*[R] readFrom mapping (start, end)", start, end)
|
||||||
n, err := r.Read(b.buf[start:end])
|
n, err := r.Read(b.buf[start:end])
|
||||||
total += int64(n) // incr total bytes read.
|
total += int64(n) // incr total bytes read.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break DONE
|
break DONE
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("<", b.buf[start:end])
|
|
||||||
|
|
||||||
// Move the head forward.
|
// Move the head forward.
|
||||||
atomic.StoreInt64(&b.head, end)
|
debug.Println(b.id, "*[R] STORE HEAD", start+int64(n))
|
||||||
|
atomic.StoreInt64(&b.head, start+int64(n))
|
||||||
b.wcond.L.Lock()
|
b.wcond.L.Lock()
|
||||||
b.wcond.Broadcast()
|
b.wcond.Broadcast()
|
||||||
b.wcond.L.Unlock()
|
b.wcond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug.Println(b.id, "*[R] FINISHED SPIN")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,40 +84,7 @@ func (b *Reader) Peek(n int64) ([]byte, error) {
|
|||||||
head := atomic.LoadInt64(&b.head)
|
head := atomic.LoadInt64(&b.head)
|
||||||
|
|
||||||
// Wait until there's at least 1 byte of data.
|
// Wait until there's at least 1 byte of data.
|
||||||
/*
|
debug.Println(b.id, "[R] PEEKING (tail, head, n)", tail, head, n)
|
||||||
b.wcond.L.Lock()
|
|
||||||
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
b.wcond.Wait()
|
|
||||||
}
|
|
||||||
b.wcond.L.Unlock()
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Figure out if we can get all n bytes.
|
|
||||||
if head == tail || head > tail && tail+n > head || head < tail && b.size-tail+head < n {
|
|
||||||
return nil, ErrInsufficientBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we've wrapped, get the bytes from the next and start.
|
|
||||||
if head < tail {
|
|
||||||
b.tmp = b.buf[tail:b.size]
|
|
||||||
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
|
|
||||||
return b.tmp, nil
|
|
||||||
} else {
|
|
||||||
return b.buf[tail : tail+n], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeekWait waits until their are n bytes available to return.
|
|
||||||
func (b *Reader) PeekWait(n int64) ([]byte, error) {
|
|
||||||
tail := atomic.LoadInt64(&b.tail)
|
|
||||||
head := atomic.LoadInt64(&b.head)
|
|
||||||
|
|
||||||
// Wait until there's at least 1 byte of data.
|
|
||||||
b.wcond.L.Lock()
|
b.wcond.L.Lock()
|
||||||
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
if atomic.LoadInt64(&b.done) == 1 {
|
||||||
@@ -116,22 +93,38 @@ func (b *Reader) PeekWait(n int64) ([]byte, error) {
|
|||||||
b.wcond.Wait()
|
b.wcond.Wait()
|
||||||
}
|
}
|
||||||
b.wcond.L.Unlock()
|
b.wcond.L.Unlock()
|
||||||
|
debug.Println(b.id, "[R] PEEKING available")
|
||||||
|
|
||||||
// Figure out if we can get all n bytes.
|
// Figure out if we can get all n bytes.
|
||||||
if head > tail && tail+n > head || head < tail && b.size-tail+head < n {
|
//if head == tail || head > tail && tail+n > head || head < tail && b.size-tail+head < n {
|
||||||
|
if head == tail || head < tail && b.size-tail+head < n {
|
||||||
|
debug.Println(b.id, "[R] insufficient bytes (tail, next, head)", tail, tail+n, head)
|
||||||
return nil, ErrInsufficientBytes
|
return nil, ErrInsufficientBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var d int64
|
||||||
|
if head < tail {
|
||||||
|
d = b.size - tail + head
|
||||||
|
} else {
|
||||||
|
d = head - tail
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek maximum of n bytes.
|
||||||
|
if d > n {
|
||||||
|
d = n
|
||||||
|
}
|
||||||
|
|
||||||
// If we've wrapped, get the bytes from the next and start.
|
// If we've wrapped, get the bytes from the next and start.
|
||||||
if head < tail {
|
if head < tail {
|
||||||
b.tmp = b.buf[tail:b.size]
|
b.tmp = b.buf[tail:b.size]
|
||||||
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
|
b.tmp = append(b.tmp, b.buf[:(tail+d)%b.size]...)
|
||||||
|
debug.Println(b.id, "[R] PEEKED wrapped (tail, next, buf)", tail, (tail+d)%b.size, b.tmp)
|
||||||
return b.tmp, nil
|
return b.tmp, nil
|
||||||
} else {
|
|
||||||
return b.buf[tail : tail+n], nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
debug.Println(b.id, "[R] PEEKED (tail, next, buf)", tail, tail+d, b.buf[tail:tail+d])
|
||||||
|
return b.buf[tail : tail+d], nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads the next n bytes from the buffer. If n bytes are not
|
// Read reads the next n bytes from the buffer. If n bytes are not
|
||||||
@@ -143,19 +136,31 @@ func (b *Reader) Read(n int64) (p []byte, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wait until there's at least len(p) bytes to read.
|
// Wait until there's at least len(p) bytes to read.
|
||||||
|
debug.Println(b.id, "[R] awaiting read fill, (n)", n)
|
||||||
tail, err := b.awaitFilled(n)
|
tail, err := b.awaitFilled(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
debug.Println(b.id, "[R] filled err (tail, err)", tail, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug.Println(b.id, "[R] filled (tail, next)", tail, tail+n)
|
||||||
|
|
||||||
// Once we have capacity, determine if the capacity wraps, and write it into
|
// Once we have capacity, determine if the capacity wraps, and write it into
|
||||||
// the buffer p.
|
// the buffer p.
|
||||||
if atomic.LoadInt64(&b.head) < tail {
|
if atomic.LoadInt64(&b.head) < tail {
|
||||||
b.tmp = b.buf[tail:b.size]
|
b.tmp = b.buf[tail:b.size]
|
||||||
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
|
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
|
||||||
|
atomic.StoreInt64(&b.tail, (tail+n)%b.size)
|
||||||
|
debug.Println(b.id, "[R] READ wrapped", b.tmp)
|
||||||
|
debug.Println(b.id, "[R] STORE TAIL wrapped", (tail+n)%b.size)
|
||||||
return b.tmp, nil
|
return b.tmp, nil
|
||||||
} else if tail+n < b.size {
|
} else if tail+n < b.size {
|
||||||
|
atomic.StoreInt64(&b.tail, tail+n)
|
||||||
|
debug.Println(b.id, "[R] READ", b.buf[tail:tail+n])
|
||||||
|
debug.Println(b.id, "[R] STORE TAIL", tail+n)
|
||||||
return b.buf[tail : tail+n], nil
|
return b.buf[tail : tail+n], nil
|
||||||
|
} else {
|
||||||
|
return p, fmt.Errorf("next byte extended size")
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
@@ -2,7 +2,6 @@ package circ
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -131,8 +130,6 @@ func TestRead(t *testing.T) {
|
|||||||
|
|
||||||
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
|
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
|
||||||
done := <-o
|
done := <-o
|
||||||
fmt.Println(done)
|
|
||||||
fmt.Println(string(done[0].([]byte)))
|
|
||||||
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
|
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
|
||||||
require.Equal(t, tt.bytes, done[0].([]byte), "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
|
require.Equal(t, tt.bytes, done[0].([]byte), "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
|
||||||
}
|
}
|
@@ -1,9 +1,10 @@
|
|||||||
package circ
|
package circ
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/mochi-co/mqtt/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Writer is a circular buffer for writing data to an io.Writer.
|
// Writer is a circular buffer for writing data to an io.Writer.
|
||||||
@@ -13,8 +14,10 @@ type Writer struct {
|
|||||||
|
|
||||||
// NewWriter returns a pointer to a new Circular Writer.
|
// NewWriter returns a pointer to a new Circular Writer.
|
||||||
func NewWriter(size, block int64) *Writer {
|
func NewWriter(size, block int64) *Writer {
|
||||||
|
b := newBuffer(size, block)
|
||||||
|
b.id = "writer"
|
||||||
return &Writer{
|
return &Writer{
|
||||||
newBuffer(size, block),
|
b,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,36 +25,45 @@ func NewWriter(size, block int64) *Writer {
|
|||||||
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
|
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
|
||||||
var p []byte
|
var p []byte
|
||||||
var n int
|
var n int
|
||||||
DONE:
|
|
||||||
for {
|
for {
|
||||||
|
cd := b.capDelta(atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head))
|
||||||
|
debug.Println(b.id, "SPIN (tail, head, delta)", b.tail, b.head, cd)
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
if atomic.LoadInt64(&b.done) == 1 {
|
||||||
err = io.EOF
|
if cd == 0 {
|
||||||
break DONE
|
err = io.EOF
|
||||||
|
return total, err
|
||||||
|
} else {
|
||||||
|
debug.Println("[W] //// capDelta not reached", cd)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
fmt.Println("writingto")
|
|
||||||
|
|
||||||
// Peek until there's bytes to write using the Peek method
|
// Peek until there's bytes to write using the Peek method
|
||||||
// of the Reader type.
|
// of the Reader type.
|
||||||
p, err = (*Reader)(b).Peek(b.block)
|
p, err = (*Reader)(b).Peek(b.block)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break DONE
|
debug.Println(b.id, "[W] writeTo peek err (p, err)", p, err)
|
||||||
|
continue
|
||||||
|
//break DONE
|
||||||
}
|
}
|
||||||
|
debug.Println(b.id, "[W] PEEKED OK (p)", p)
|
||||||
fmt.Println(">", p)
|
|
||||||
|
|
||||||
// Write the peeked bytes to the io.Writer.
|
// Write the peeked bytes to the io.Writer.
|
||||||
n, err = w.Write(p)
|
n, err = w.Write(p)
|
||||||
total += int64(n)
|
total += int64(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break DONE
|
return total, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug.Println(b.id, "[W] SENT (n, p)", n, p)
|
||||||
|
|
||||||
// Move the tail forward the bytes written.
|
// Move the tail forward the bytes written.
|
||||||
end := (atomic.LoadInt64(&b.tail) + int64(n)) % b.size
|
end := (atomic.LoadInt64(&b.tail) + int64(n)) % b.size
|
||||||
|
debug.Println(b.id, "[W] STORE TAIL", end)
|
||||||
atomic.StoreInt64(&b.tail, end)
|
atomic.StoreInt64(&b.tail, end)
|
||||||
b.rcond.L.Lock()
|
b.rcond.L.Lock()
|
||||||
b.rcond.Broadcast()
|
b.rcond.Broadcast()
|
||||||
b.rcond.L.Unlock()
|
b.rcond.L.Unlock()
|
||||||
|
debug.Println(b.id, "[W] writeTo unlocked")
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -71,6 +83,16 @@ func (b *Writer) Write(p []byte) (nn int, err error) {
|
|||||||
|
|
||||||
// Write the outgoing bytes to the buffer, wrapping if necessary.
|
// Write the outgoing bytes to the buffer, wrapping if necessary.
|
||||||
nn = b.writeBytes(p)
|
nn = b.writeBytes(p)
|
||||||
|
debug.Println(b.id, "[W] bytes written (nn)", nn, p)
|
||||||
|
|
||||||
|
// Move the head forward.
|
||||||
|
next := (atomic.LoadInt64(&b.head) + int64(nn)) % b.size
|
||||||
|
atomic.StoreInt64(&b.head, next)
|
||||||
|
debug.Println(b.id, "[W] STORE HEAD", next)
|
||||||
|
|
||||||
|
b.wcond.L.Lock()
|
||||||
|
b.wcond.Broadcast()
|
||||||
|
b.wcond.L.Unlock()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
@@ -1,7 +1,6 @@
|
|||||||
package circ
|
package circ
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
@@ -104,7 +103,6 @@ func TestWrite(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
done := <-o
|
done := <-o
|
||||||
fmt.Println(done)
|
|
||||||
require.Equal(t, tt.want, done[0].(int), "Wanted written mismatch [i:%d] %s", i, tt.desc)
|
require.Equal(t, tt.want, done[0].(int), "Wanted written mismatch [i:%d] %s", i, tt.desc)
|
||||||
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
|
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
|
||||||
|
|
21
clients.go
21
clients.go
@@ -1,6 +1,7 @@
|
|||||||
package mqtt
|
package mqtt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@@ -77,8 +78,8 @@ func (cl *clients) getByListener(id string) []*client {
|
|||||||
type client struct {
|
type client struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
||||||
// p is a packets parser which reads incoming packets.
|
// p is a packets processor which reads and writes packets.
|
||||||
p *Parser
|
p *Processor
|
||||||
|
|
||||||
// ac is a pointer to an auth controller inherited from the listener.
|
// ac is a pointer to an auth controller inherited from the listener.
|
||||||
ac auth.Controller
|
ac auth.Controller
|
||||||
@@ -123,7 +124,7 @@ type client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newClient creates a new instance of client.
|
// newClient creates a new instance of client.
|
||||||
func newClient(p *Parser, pk *packets.ConnectPacket, ac auth.Controller) *client {
|
func newClient(p *Processor, pk *packets.ConnectPacket, ac auth.Controller) *client {
|
||||||
cl := &client{
|
cl := &client{
|
||||||
p: p,
|
p: p,
|
||||||
ac: ac,
|
ac: ac,
|
||||||
@@ -192,11 +193,19 @@ func (c *client) forgetSubscription(filter string) {
|
|||||||
// close attempts to gracefully close a client connection.
|
// close attempts to gracefully close a client connection.
|
||||||
func (cl *client) close() {
|
func (cl *client) close() {
|
||||||
cl.done.Do(func() {
|
cl.done.Do(func() {
|
||||||
close(cl.end) // Signal to stop listening for packets.
|
if cl.end != nil {
|
||||||
|
close(cl.end) // Signal to stop listening for packets in mqtt readClient loop.
|
||||||
|
fmt.Println("closed readClient end")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the processor.
|
||||||
|
cl.p.Stop()
|
||||||
|
fmt.Println("signalled stop, waiting")
|
||||||
|
fmt.Println("all processors ended")
|
||||||
|
|
||||||
// Close the network connection.
|
// Close the network connection.
|
||||||
cl.p.Conn.Close() // Error is irrelevant so can be ommitted here.
|
//cl.p.Conn.Close() // Error is irrelevant so can be ommitted here.
|
||||||
cl.p.Conn = nil
|
//cl.p.Conn = nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -2,13 +2,13 @@ package mqtt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"net"
|
//"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/mochi-co/mqtt/auth"
|
"github.com/mochi-co/mqtt/auth"
|
||||||
|
"github.com/mochi-co/mqtt/circ"
|
||||||
"github.com/mochi-co/mqtt/packets"
|
"github.com/mochi-co/mqtt/packets"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewClients(t *testing.T) {
|
func TestNewClients(t *testing.T) {
|
||||||
@@ -116,9 +116,7 @@ func BenchmarkClientsGetByListener(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNewClient(t *testing.T) {
|
func TestNewClient(t *testing.T) {
|
||||||
r, _ := net.Pipe()
|
p := NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
|
||||||
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
|
|
||||||
r.Close()
|
|
||||||
pk := &packets.ConnectPacket{
|
pk := &packets.ConnectPacket{
|
||||||
FixedHeader: packets.FixedHeader{
|
FixedHeader: packets.FixedHeader{
|
||||||
Type: packets.Connect,
|
Type: packets.Connect,
|
||||||
@@ -153,9 +151,7 @@ func TestNewClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNewClientLWT(t *testing.T) {
|
func TestNewClientLWT(t *testing.T) {
|
||||||
r, _ := net.Pipe()
|
p := NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
|
||||||
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
|
|
||||||
r.Close()
|
|
||||||
pk := &packets.ConnectPacket{
|
pk := &packets.ConnectPacket{
|
||||||
FixedHeader: packets.FixedHeader{
|
FixedHeader: packets.FixedHeader{
|
||||||
Type: packets.Connect,
|
Type: packets.Connect,
|
||||||
@@ -181,11 +177,8 @@ func TestNewClientLWT(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkNewClient(b *testing.B) {
|
func BenchmarkNewClient(b *testing.B) {
|
||||||
r, _ := net.Pipe()
|
p := NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
|
||||||
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
|
|
||||||
r.Close()
|
|
||||||
pk := new(packets.ConnectPacket)
|
pk := new(packets.ConnectPacket)
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
newClient(p, pk, new(auth.Allow))
|
newClient(p, pk, new(auth.Allow))
|
||||||
}
|
}
|
||||||
@@ -244,15 +237,13 @@ func BenchmarkClientForgetSubscription(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClientClose(t *testing.T) {
|
func TestClientClose(t *testing.T) {
|
||||||
r, w := net.Pipe()
|
p := NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
|
||||||
p := NewParser(r, newBufioReader(r), newBufioWriter(w))
|
|
||||||
pk := &packets.ConnectPacket{
|
pk := &packets.ConnectPacket{
|
||||||
ClientIdentifier: "zen3",
|
ClientIdentifier: "zen3",
|
||||||
}
|
}
|
||||||
|
|
||||||
client := newClient(p, pk, new(auth.Allow))
|
client := newClient(p, pk, new(auth.Allow))
|
||||||
require.NotNil(t, client)
|
require.NotNil(t, client)
|
||||||
|
|
||||||
client.close()
|
client.close()
|
||||||
|
|
||||||
var ok bool
|
var ok bool
|
||||||
@@ -261,8 +252,6 @@ func TestClientClose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.Equal(t, false, ok)
|
require.Equal(t, false, ok)
|
||||||
require.Nil(t, client.p.Conn)
|
require.Nil(t, client.p.Conn)
|
||||||
r.Close()
|
|
||||||
w.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInFlightSet(t *testing.T) {
|
func TestInFlightSet(t *testing.T) {
|
||||||
|
9
debug/log.go
Normal file
9
debug/log.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
// +build !debug
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
// Println will be optimized away.
|
||||||
|
func Println(a ...interface{}) {}
|
||||||
|
|
||||||
|
// Printf will be optimized away.
|
||||||
|
func Printf(format string, a ...interface{}) {}
|
17
debug/log_debug.go
Normal file
17
debug/log_debug.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
// +build debug
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Println will be optimized away.
|
||||||
|
func Println(a ...interface{}) {
|
||||||
|
fmt.Println(a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Printf will be optimized away.
|
||||||
|
func Printf(format string, a ...interface{}) {
|
||||||
|
fmt.Printf(format, a...)
|
||||||
|
}
|
129
mqtt.go
129
mqtt.go
@@ -1,18 +1,18 @@
|
|||||||
package mqtt
|
package mqtt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
//"github.com/davecgh/go-spew/spew"
|
//"github.com/davecgh/go-spew/spew"
|
||||||
|
|
||||||
"github.com/mochi-co/mqtt/auth"
|
"github.com/mochi-co/mqtt/auth"
|
||||||
|
"github.com/mochi-co/mqtt/circ"
|
||||||
"github.com/mochi-co/mqtt/listeners"
|
"github.com/mochi-co/mqtt/listeners"
|
||||||
"github.com/mochi-co/mqtt/packets"
|
"github.com/mochi-co/mqtt/packets"
|
||||||
"github.com/mochi-co/mqtt/processors/circ"
|
|
||||||
"github.com/mochi-co/mqtt/topics"
|
"github.com/mochi-co/mqtt/topics"
|
||||||
"github.com/mochi-co/mqtt/topics/trie"
|
"github.com/mochi-co/mqtt/topics/trie"
|
||||||
)
|
)
|
||||||
@@ -36,9 +36,6 @@ var (
|
|||||||
ErrConnectionClosed = errors.New("Connection not open")
|
ErrConnectionClosed = errors.New("Connection not open")
|
||||||
ErrNoData = errors.New("No data")
|
ErrNoData = errors.New("No data")
|
||||||
ErrACLNotAuthorized = errors.New("ACL not authorized")
|
ErrACLNotAuthorized = errors.New("ACL not authorized")
|
||||||
|
|
||||||
// rwBufSize is the size of client read/write buffers.
|
|
||||||
rwBufSize = 512
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server is an MQTT broker server.
|
// Server is an MQTT broker server.
|
||||||
@@ -91,16 +88,19 @@ func (s *Server) AddListener(listener listeners.Listener, config *listeners.Conf
|
|||||||
// Serve begins the event loops for establishing client connections on all
|
// Serve begins the event loops for establishing client connections on all
|
||||||
// attached listeners.
|
// attached listeners.
|
||||||
func (s *Server) Serve() error {
|
func (s *Server) Serve() error {
|
||||||
s.listeners.ServeAll(s.EstablishConnection2)
|
s.listeners.ServeAll(s.EstablishConnection)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EstablishConnection establishes a new client connection with the broker.
|
// EstablishConnection establishes a new client connection with the broker.
|
||||||
func (s *Server) EstablishConnection2(lid string, c net.Conn, ac auth.Controller) error {
|
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
|
||||||
var size int64 = 1024 * 256
|
p := NewProcessor(
|
||||||
|
c,
|
||||||
p := NewProcessor(c, circ.NewReader(size), circ.NewWriter(size))
|
circ.NewReader(0, 0),
|
||||||
|
circ.NewWriter(0, 0),
|
||||||
|
)
|
||||||
|
p.Start()
|
||||||
|
|
||||||
// Pull the header from the first packet and check for a CONNECT message.
|
// Pull the header from the first packet and check for a CONNECT message.
|
||||||
fh := new(packets.FixedHeader)
|
fh := new(packets.FixedHeader)
|
||||||
@@ -109,9 +109,94 @@ func (s *Server) EstablishConnection2(lid string, c net.Conn, ac auth.Controller
|
|||||||
return ErrReadConnectFixedHeader
|
return ErrReadConnectFixedHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Read the first packet expecting a CONNECT message.
|
||||||
|
pk, err := p.Read()
|
||||||
|
if err != nil {
|
||||||
|
return ErrReadConnectPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure first packet is a connect packet.
|
||||||
|
msg, ok := pk.(*packets.ConnectPacket)
|
||||||
|
if !ok {
|
||||||
|
return ErrFirstPacketInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the packet conforms to MQTT CONNECT specifications.
|
||||||
|
retcode, _ := msg.Validate()
|
||||||
|
if retcode != packets.Accepted {
|
||||||
|
return ErrReadConnectInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a username and password has been provided, perform authentication.
|
||||||
|
if msg.Username != "" && !ac.Authenticate(msg.Username, msg.Password) {
|
||||||
|
retcode = packets.CodeConnectNotAuthorised
|
||||||
|
return ErrConnectNotAuthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new client with this connection.
|
||||||
|
client := newClient(p, msg, ac)
|
||||||
|
client.listener = lid // Note the listener the client is connected to.
|
||||||
|
|
||||||
|
// If it's not a clean session and a client already exists with the same
|
||||||
|
// id, inherit the session.
|
||||||
|
var sessionPresent bool
|
||||||
|
if existing, ok := s.clients.get(msg.ClientIdentifier); ok {
|
||||||
|
existing.Lock()
|
||||||
|
existing.close()
|
||||||
|
if msg.CleanSession {
|
||||||
|
for k := range existing.subscriptions {
|
||||||
|
s.topics.Unsubscribe(k, existing.id)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
client.inFlight = existing.inFlight // Inherit from existing session.
|
||||||
|
client.subscriptions = existing.subscriptions
|
||||||
|
sessionPresent = true
|
||||||
|
}
|
||||||
|
existing.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the new client to the clients manager.
|
||||||
|
s.clients.add(client)
|
||||||
|
|
||||||
|
// Send a CONNACK back to the client.
|
||||||
|
err = s.writeClient(client, &packets.ConnackPacket{
|
||||||
|
FixedHeader: packets.FixedHeader{
|
||||||
|
Type: packets.Connack,
|
||||||
|
},
|
||||||
|
SessionPresent: sessionPresent,
|
||||||
|
ReturnCode: retcode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("WRITECLIENT err", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("----------------")
|
||||||
|
|
||||||
|
// Resend any unacknowledged QOS messages still pending for the client.
|
||||||
|
/*err = s.resendInflight(client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}*/
|
||||||
|
|
||||||
|
// Block and listen for more packets, and end if an error or nil packet occurs.
|
||||||
|
var sendLWT bool
|
||||||
|
err = s.readClient(client)
|
||||||
|
if err != nil {
|
||||||
|
sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3]
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("** stop and wait", err)
|
||||||
|
|
||||||
|
// Publish last will and testament.
|
||||||
|
s.closeClient(client, sendLWT)
|
||||||
|
|
||||||
|
fmt.Println("stopped")
|
||||||
|
|
||||||
|
return err // capture err or nil from readClient.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
// EstablishConnection establishes a new client connection with the broker.
|
// EstablishConnection establishes a new client connection with the broker.
|
||||||
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
|
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
|
||||||
|
|
||||||
@@ -205,6 +290,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
|
|||||||
|
|
||||||
return err // capture err or nil from readClient.
|
return err // capture err or nil from readClient.
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
// resendInflight republishes any inflight messages to the client.
|
// resendInflight republishes any inflight messages to the client.
|
||||||
func (s *Server) resendInflight(cl *client) error {
|
func (s *Server) resendInflight(cl *client) error {
|
||||||
@@ -226,7 +312,7 @@ func (s *Server) readClient(cl *client) error {
|
|||||||
var err error
|
var err error
|
||||||
var pk packets.Packet
|
var pk packets.Packet
|
||||||
fh := new(packets.FixedHeader)
|
fh := new(packets.FixedHeader)
|
||||||
|
fmt.Println("STARTING READ CLIENT LOOP")
|
||||||
DONE:
|
DONE:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -250,15 +336,20 @@ DONE:
|
|||||||
|
|
||||||
// If it's a disconnect packet, begin the close process.
|
// If it's a disconnect packet, begin the close process.
|
||||||
if fh.Type == packets.Disconnect {
|
if fh.Type == packets.Disconnect {
|
||||||
return nil
|
fmt.Println(" X got disconnect")
|
||||||
|
break DONE
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise read in the packet payload.
|
// Otherwise read in the packet payload.
|
||||||
pk, err = cl.p.Read()
|
pk, err = cl.p.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrReadPacketPayload
|
fmt.Println("RC READ ERR", err)
|
||||||
|
//return ErrReadPacketPayload
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("READ PACKET", pk)
|
||||||
|
|
||||||
// Validate the packet if necessary.
|
// Validate the packet if necessary.
|
||||||
_, err = pk.Validate()
|
_, err = pk.Validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -270,6 +361,7 @@ DONE:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("READCLIENT COMPLETED")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -555,8 +647,8 @@ func (s *Server) processUnsubscribe(cl *client, pk *packets.UnsubscribePacket) e
|
|||||||
|
|
||||||
// writeClient writes packets to a client connection.
|
// writeClient writes packets to a client connection.
|
||||||
func (s *Server) writeClient(cl *client, pk packets.Packet) error {
|
func (s *Server) writeClient(cl *client, pk packets.Packet) error {
|
||||||
cl.p.Lock()
|
cl.p.W.Lock()
|
||||||
defer cl.p.Unlock()
|
defer cl.p.W.Unlock()
|
||||||
|
|
||||||
// Ensure Writer is open.
|
// Ensure Writer is open.
|
||||||
if cl.p.W == nil {
|
if cl.p.W == nil {
|
||||||
@@ -575,11 +667,6 @@ func (s *Server) writeClient(cl *client, pk packets.Packet) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cl.p.W.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh deadline to keep the connection alive.
|
// Refresh deadline to keep the connection alive.
|
||||||
cl.p.RefreshDeadline(cl.keepalive)
|
cl.p.RefreshDeadline(cl.keepalive)
|
||||||
|
|
||||||
|
90
mqtt_test.go
90
mqtt_test.go
@@ -2,6 +2,7 @@ package mqtt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -10,11 +11,18 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/mochi-co/mqtt/auth"
|
"github.com/mochi-co/mqtt/auth"
|
||||||
|
"github.com/mochi-co/mqtt/circ"
|
||||||
"github.com/mochi-co/mqtt/listeners"
|
"github.com/mochi-co/mqtt/listeners"
|
||||||
"github.com/mochi-co/mqtt/packets"
|
"github.com/mochi-co/mqtt/packets"
|
||||||
"github.com/mochi-co/mqtt/topics"
|
//"github.com/mochi-co/mqtt/topics"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// bufferSize is the buffer size of the processors.
|
||||||
|
var bufferSize int64 = 512
|
||||||
|
|
||||||
|
// blockSize is the block size of the processors.
|
||||||
|
var blockSize int64 = 128
|
||||||
|
|
||||||
type quietWriter struct {
|
type quietWriter struct {
|
||||||
b []byte
|
b []byte
|
||||||
f [][]byte
|
f [][]byte
|
||||||
@@ -45,9 +53,8 @@ func (q *quietWriter) Flush() error {
|
|||||||
|
|
||||||
func setupClient(id string) (s *Server, r net.Conn, w net.Conn, cl *client) {
|
func setupClient(id string) (s *Server, r net.Conn, w net.Conn, cl *client) {
|
||||||
s = New()
|
s = New()
|
||||||
r, w = net.Pipe()
|
|
||||||
cl = newClient(
|
cl = newClient(
|
||||||
NewParser(r, newBufioReader(r), newBufioWriter(w)),
|
NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize)),
|
||||||
&packets.ConnectPacket{
|
&packets.ConnectPacket{
|
||||||
ClientIdentifier: id,
|
ClientIdentifier: id,
|
||||||
},
|
},
|
||||||
@@ -178,10 +185,64 @@ func BenchmarkServerClose(b *testing.B) {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
|
func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
|
||||||
r2, w2 := net.Pipe()
|
|
||||||
s := New()
|
s := New()
|
||||||
|
r, w := net.Pipe()
|
||||||
|
o := make(chan error)
|
||||||
|
go func() {
|
||||||
|
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
w.Write([]byte{
|
||||||
|
byte(packets.Connect << 4), 15, // Fixed header
|
||||||
|
0, 4, // Protocol Name - MSB+LSB
|
||||||
|
'M', 'Q', 'T', 'T', // Protocol Name
|
||||||
|
4, // Protocol Version
|
||||||
|
2, // Packet Flags - clean session
|
||||||
|
0, 45, // Keepalive
|
||||||
|
0, 3, // Client ID - MSB+LSB
|
||||||
|
'z', 'e', 'n', // Client ID "zen"
|
||||||
|
})
|
||||||
|
w.Write([]byte{byte(packets.Disconnect << 4), 0})
|
||||||
|
}()
|
||||||
|
|
||||||
|
recv := make(chan []byte)
|
||||||
|
go func() {
|
||||||
|
buf, err := ioutil.ReadAll(w)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
recv <- buf
|
||||||
|
}()
|
||||||
|
|
||||||
|
//time.Sleep(10 * time.Millisecond)
|
||||||
|
s.clients.Lock()
|
||||||
|
for _, v := range s.clients.internal {
|
||||||
|
v.close()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
s.clients.Unlock()
|
||||||
|
|
||||||
|
errx := <-o
|
||||||
|
require.NoError(t, errx)
|
||||||
|
require.Equal(t, []byte{
|
||||||
|
byte(packets.Connack << 4), 2,
|
||||||
|
0, packets.Accepted,
|
||||||
|
}, <-recv)
|
||||||
|
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
w.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
|
||||||
|
r2, _ := net.Pipe()
|
||||||
|
s := New()
|
||||||
|
px := NewProcessor(r2, circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
|
||||||
|
px.Start()
|
||||||
s.clients.internal["zen"] = newClient(
|
s.clients.internal["zen"] = newClient(
|
||||||
NewParser(r2, newBufioReader(r2), newBufioWriter(w2)),
|
px,
|
||||||
&packets.ConnectPacket{ClientIdentifier: "zen"},
|
&packets.ConnectPacket{ClientIdentifier: "zen"},
|
||||||
new(auth.Allow),
|
new(auth.Allow),
|
||||||
)
|
)
|
||||||
@@ -228,6 +289,8 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
|
|||||||
w.Close()
|
w.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
|
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
|
||||||
r2, w2 := net.Pipe()
|
r2, w2 := net.Pipe()
|
||||||
s := New()
|
s := New()
|
||||||
@@ -496,8 +559,8 @@ func TestResendInflightWriteError(t *testing.T) {
|
|||||||
|
|
||||||
* Server Read Client
|
* Server Read Client
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
/*
|
||||||
func TestServerReadClientOK(t *testing.T) {
|
func TestServerReadClientOK(t *testing.T) {
|
||||||
s, r, w, cl := setupClient("zen")
|
s, r, w, cl := setupClient("zen")
|
||||||
|
|
||||||
@@ -603,8 +666,8 @@ func TestServerReadClientBadPacketValidation(t *testing.T) {
|
|||||||
|
|
||||||
* Server Write Client
|
* Server Write Client
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
/*
|
||||||
func TestServerWriteClient(t *testing.T) {
|
func TestServerWriteClient(t *testing.T) {
|
||||||
s, _, _, cl := setupClient("zen")
|
s, _, _, cl := setupClient("zen")
|
||||||
cl.p.W = new(quietWriter)
|
cl.p.W = new(quietWriter)
|
||||||
@@ -676,8 +739,8 @@ func TestServerWriteClientWriteError(t *testing.T) {
|
|||||||
|
|
||||||
* Server Close Client
|
* Server Close Client
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
/*
|
||||||
func TestServerCloseClient(t *testing.T) { // as opposed to client.close
|
func TestServerCloseClient(t *testing.T) { // as opposed to client.close
|
||||||
s, _, _, cl := setupClient("zen")
|
s, _, _, cl := setupClient("zen")
|
||||||
s.clients.add(cl)
|
s.clients.add(cl)
|
||||||
@@ -748,8 +811,8 @@ func TestServerCloseClientLWTWriteError(t *testing.T) { // as opposed to client.
|
|||||||
|
|
||||||
* Server Process Packets
|
* Server Process Packets
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
/*
|
||||||
func TestServerProcessPacket(t *testing.T) {
|
func TestServerProcessPacket(t *testing.T) {
|
||||||
s, _, _, cl := setupClient("zen")
|
s, _, _, cl := setupClient("zen")
|
||||||
err := s.processPacket(cl, nil)
|
err := s.processPacket(cl, nil)
|
||||||
@@ -1368,3 +1431,4 @@ func TestServerProcessUnsubscribeWriteError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
192
parser.go
192
parser.go
@@ -1,192 +0,0 @@
|
|||||||
package mqtt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mochi-co/mqtt/packets"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BufWriter is an interface for satisfying a bufio.Writer. This is mainly
|
|
||||||
// in place to allow testing.
|
|
||||||
type BufWriter interface {
|
|
||||||
|
|
||||||
// Write writes a byte buffer.
|
|
||||||
Write(p []byte) (nn int, err error)
|
|
||||||
|
|
||||||
// Flush flushes the buffer.
|
|
||||||
Flush() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parser is an MQTT packet parser that reads and writes MQTT payloads to a
|
|
||||||
// buffered IO stream.
|
|
||||||
type Parser struct {
|
|
||||||
sync.RWMutex
|
|
||||||
|
|
||||||
// Conn is the net.Conn used to establish the connection.
|
|
||||||
Conn net.Conn
|
|
||||||
|
|
||||||
// R is a bufio reader for peeking and reading incoming packets.
|
|
||||||
R *bufio.Reader
|
|
||||||
|
|
||||||
// W is a bufio writer for writing outgoing packets.
|
|
||||||
W BufWriter
|
|
||||||
|
|
||||||
// FixedHeader is the fixed header from the last seen packet.
|
|
||||||
FixedHeader packets.FixedHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewParser returns an instance of Parser for a connection.
|
|
||||||
func NewParser(c net.Conn, r *bufio.Reader, w BufWriter) *Parser {
|
|
||||||
return &Parser{
|
|
||||||
Conn: c,
|
|
||||||
R: r,
|
|
||||||
W: w,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
|
||||||
func (p *Parser) RefreshDeadline(keepalive uint16) {
|
|
||||||
if p.Conn != nil {
|
|
||||||
expiry := time.Duration(keepalive+(keepalive/2)) * time.Second
|
|
||||||
p.Conn.SetDeadline(time.Now().Add(expiry))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
|
||||||
func (p *Parser) ReadFixedHeader(fh *packets.FixedHeader) error {
|
|
||||||
|
|
||||||
// Peek the maximum message type and flags, and length.
|
|
||||||
peeked, err := p.R.Peek(1)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unpack message type and flags from byte 1.
|
|
||||||
err = fh.Decode(peeked[0])
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
// @SPEC [MQTT-2.2.2-2]
|
|
||||||
// If invalid flags are received, the receiver MUST close the Network Connection.
|
|
||||||
return packets.ErrInvalidFlags
|
|
||||||
}
|
|
||||||
|
|
||||||
// The remaining length value can be up to 5 bytes. Peek through each byte
|
|
||||||
// looking for continue values, and if found increase the peek. Otherwise
|
|
||||||
// decode the bytes that were legit.
|
|
||||||
//p.fhBuffer = p.fhBuffer[:0]
|
|
||||||
buf := make([]byte, 0, 6)
|
|
||||||
i := 1
|
|
||||||
b := 2
|
|
||||||
for ; b < 6; b++ {
|
|
||||||
peeked, err = p.R.Peek(b)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the byte to the length bytes slice.
|
|
||||||
buf = append(buf, peeked[i])
|
|
||||||
|
|
||||||
// If it's not a continuation flag, end here.
|
|
||||||
if peeked[i] < 128 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// If i has reached 4 without a length terminator, throw a protocol violation.
|
|
||||||
i++
|
|
||||||
if i == 4 {
|
|
||||||
return packets.ErrOversizedLengthIndicator
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate and store the remaining length.
|
|
||||||
rem, _ := binary.Uvarint(buf)
|
|
||||||
fh.Remaining = int(rem)
|
|
||||||
|
|
||||||
// Discard the number of used length bytes + first byte.
|
|
||||||
p.R.Discard(b)
|
|
||||||
|
|
||||||
// Set the fixed header in the parser.
|
|
||||||
p.FixedHeader = *fh
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads the remaining buffer into an MQTT packet.
|
|
||||||
func (p *Parser) Read() (pk packets.Packet, err error) {
|
|
||||||
|
|
||||||
switch p.FixedHeader.Type {
|
|
||||||
case packets.Connect:
|
|
||||||
pk = &packets.ConnectPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Connack:
|
|
||||||
pk = &packets.ConnackPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Publish:
|
|
||||||
pk = &packets.PublishPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Puback:
|
|
||||||
pk = &packets.PubackPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Pubrec:
|
|
||||||
pk = &packets.PubrecPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Pubrel:
|
|
||||||
pk = &packets.PubrelPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Pubcomp:
|
|
||||||
pk = &packets.PubcompPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Subscribe:
|
|
||||||
pk = &packets.SubscribePacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Suback:
|
|
||||||
pk = &packets.SubackPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Unsubscribe:
|
|
||||||
pk = &packets.UnsubscribePacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Unsuback:
|
|
||||||
pk = &packets.UnsubackPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Pingreq:
|
|
||||||
pk = &packets.PingreqPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Pingresp:
|
|
||||||
pk = &packets.PingrespPacket{FixedHeader: p.FixedHeader}
|
|
||||||
case packets.Disconnect:
|
|
||||||
pk = &packets.DisconnectPacket{FixedHeader: p.FixedHeader}
|
|
||||||
default:
|
|
||||||
return pk, errors.New("No valid packet available; " + string(p.FixedHeader.Type))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to peek the rest of the packet.
|
|
||||||
peeked := true
|
|
||||||
|
|
||||||
bt, err := p.R.Peek(p.FixedHeader.Remaining)
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
// Only try to continue if reading is still possible.
|
|
||||||
if err != bufio.ErrBufferFull {
|
|
||||||
return pk, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it didn't work, read the buffer directly, and if that still doesn't
|
|
||||||
// work, then throw an error.
|
|
||||||
peeked = false
|
|
||||||
bt = make([]byte, p.FixedHeader.Remaining)
|
|
||||||
_, err := io.ReadFull(p.R, bt)
|
|
||||||
if err != nil {
|
|
||||||
return pk, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If peeking was successful, discard the rest of the packet now it's been read.
|
|
||||||
if peeked {
|
|
||||||
p.R.Discard(p.FixedHeader.Remaining)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode the remaining packet values using a fresh copy of the bytes,
|
|
||||||
// otherwise the next packet will change the data of this one.
|
|
||||||
// 🚨 DANGER!! This line is super important. If the bytes being decoded are not
|
|
||||||
// in their own memory space, packets will get corrupted all over the place.
|
|
||||||
err = pk.Decode(append([]byte{}, bt[:]...)) // <--- MUST BE A COPY.
|
|
||||||
if err != nil {
|
|
||||||
return pk, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
528
parser_test.go
528
parser_test.go
@@ -1,528 +0,0 @@
|
|||||||
package mqtt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
//"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/mochi-co/mqtt/packets"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newBufioReader(c net.Conn) *bufio.Reader {
|
|
||||||
return bufio.NewReaderSize(c, 512)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBufioWriter(c net.Conn) *bufio.Writer {
|
|
||||||
return bufio.NewWriterSize(c, 512)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewParser(t *testing.T) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
require.NotNil(t, p.R)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkNewParser(b *testing.B) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
r, w := new(bufio.Reader), new(bufio.Writer)
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
NewParser(conn, r, w)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRefreshDeadline(t *testing.T) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
|
|
||||||
dl := p.Conn.(*MockNetConn).Deadline
|
|
||||||
p.RefreshDeadline(10)
|
|
||||||
|
|
||||||
require.NotEqual(t, dl, p.Conn.(*MockNetConn).Deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkRefreshDeadline(b *testing.B) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
p.RefreshDeadline(10)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
type fixedHeaderTable struct {
|
|
||||||
rawBytes []byte
|
|
||||||
header packets.FixedHeader
|
|
||||||
packetError bool
|
|
||||||
flagError bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var fixedHeaderExpected = []fixedHeaderTable{
|
|
||||||
{
|
|
||||||
rawBytes: []byte{packets.Connect << 4, 0x00},
|
|
||||||
header: packets.FixedHeader{packets.Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{packets.Connack << 4, 0x00},
|
|
||||||
header: packets.FixedHeader{packets.Connack, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{packets.Publish << 4, 0x00},
|
|
||||||
header: packets.FixedHeader{packets.Publish, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{packets.Publish<<4 | 1<<1, 0x00},
|
|
||||||
header: packets.FixedHeader{packets.Publish, false, 1, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{packets.Publish<<4 | 1<<1 | 1, 0x00},
|
|
||||||
header: packets.FixedHeader{packets.Publish, false, 1, true, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{packets.Publish<<4 | 2<<1, 0x00},
|
|
||||||
header: packets.FixedHeader{packets.Publish, false, 2, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
|
|
||||||
header: FixedHeader{Publish, false, 2, true, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
|
|
||||||
header: FixedHeader{Publish, true, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
|
|
||||||
header: FixedHeader{Publish, true, 0, true, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
|
|
||||||
header: FixedHeader{Publish, true, 1, true, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
|
|
||||||
header: FixedHeader{Publish, true, 2, true, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Puback << 4, 0x00},
|
|
||||||
header: FixedHeader{Puback, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Pubrec << 4, 0x00},
|
|
||||||
header: FixedHeader{Pubrec, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
|
|
||||||
header: FixedHeader{Pubrel, false, 1, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Pubcomp << 4, 0x00},
|
|
||||||
header: FixedHeader{Pubcomp, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
|
|
||||||
header: FixedHeader{Subscribe, false, 1, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Suback << 4, 0x00},
|
|
||||||
header: FixedHeader{Suback, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
|
|
||||||
header: FixedHeader{Unsubscribe, false, 1, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Unsuback << 4, 0x00},
|
|
||||||
header: FixedHeader{Unsuback, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Pingreq << 4, 0x00},
|
|
||||||
header: FixedHeader{Pingreq, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Pingresp << 4, 0x00},
|
|
||||||
header: FixedHeader{Pingresp, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Disconnect << 4, 0x00},
|
|
||||||
header: FixedHeader{Disconnect, false, 0, false, 0},
|
|
||||||
},
|
|
||||||
|
|
||||||
// remaining length
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish << 4, 0x0a},
|
|
||||||
header: FixedHeader{Publish, false, 0, false, 10},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish << 4, 0x80, 0x04},
|
|
||||||
header: FixedHeader{Publish, false, 0, false, 512},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
|
|
||||||
header: FixedHeader{Publish, false, 0, false, 978},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
|
|
||||||
header: FixedHeader{Publish, false, 0, false, 20102},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
|
|
||||||
header: FixedHeader{Publish, false, 0, false, 333333333},
|
|
||||||
packetError: true,
|
|
||||||
},
|
|
||||||
|
|
||||||
// Invalid flags for packet
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
|
|
||||||
header: FixedHeader{Connect, true, 0, false, 0},
|
|
||||||
flagError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
|
|
||||||
header: FixedHeader{Connect, false, 1, false, 0},
|
|
||||||
flagError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
rawBytes: []byte{Connect<<4 | 1, 0x00},
|
|
||||||
header: FixedHeader{Connect, false, 0, true, 0},
|
|
||||||
flagError: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func TestParserReadFixedHeader(t *testing.T) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
|
|
||||||
// Test null data.
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
fh := new(packets.FixedHeader)
|
|
||||||
err := p.ReadFixedHeader(fh)
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
// Test insufficient peeking.
|
|
||||||
fh = new(packets.FixedHeader)
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader([]byte{packets.Connect << 4}))
|
|
||||||
err = p.ReadFixedHeader(fh)
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
fh = new(packets.FixedHeader)
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader([]byte{packets.Connect << 4, 0x00}))
|
|
||||||
err = p.ReadFixedHeader(fh)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
/*
|
|
||||||
|
|
||||||
|
|
||||||
// Test expected bytes.
|
|
||||||
for i, wanted := range fixedHeaderExpected {
|
|
||||||
fh := new(packets.FixedHeader)
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
b := wanted.rawBytes
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader(b))
|
|
||||||
|
|
||||||
err := p.ReadFixedHeader(fh)
|
|
||||||
if wanted.packetError || wanted.flagError {
|
|
||||||
require.Error(t, err, "Expected error [i:%d] %v", i, wanted.rawBytes)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
|
|
||||||
require.Equal(t, wanted.header.Type, p.FixedHeader.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
|
|
||||||
require.Equal(t, wanted.header.Dup, p.FixedHeader.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
|
|
||||||
require.Equal(t, wanted.header.Qos, p.FixedHeader.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
|
|
||||||
require.Equal(t, wanted.header.Retain, p.FixedHeader.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkReadFixedHeader(b *testing.B) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
fh := new(packets.FixedHeader)
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
|
|
||||||
var rn bytes.Reader = *bytes.NewReader([]byte{packets.Connect << 4, 0x00})
|
|
||||||
var rc bytes.Reader
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
rc = rn
|
|
||||||
p.R.Reset(&rc)
|
|
||||||
err := p.ReadFixedHeader(fh)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRead(t *testing.T) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
|
|
||||||
var fh packets.FixedHeader
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader([]byte{
|
|
||||||
byte(packets.Publish << 4), 18, // Fixed header
|
|
||||||
0, 5, // Topic Name - LSB+MSB
|
|
||||||
'a', '/', 'b', '/', 'c', // Topic Name
|
|
||||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
|
|
||||||
}))
|
|
||||||
err := p.ReadFixedHeader(&fh)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
pko, err := p.Read()
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, &packets.PublishPacket{
|
|
||||||
FixedHeader: packets.FixedHeader{
|
|
||||||
Type: packets.Publish,
|
|
||||||
Remaining: 18,
|
|
||||||
},
|
|
||||||
TopicName: "a/b/c",
|
|
||||||
Payload: []byte("hello mochi"),
|
|
||||||
}, pko)
|
|
||||||
|
|
||||||
/*
|
|
||||||
for code, pt := range expectedPackets {
|
|
||||||
for i, wanted := range pt {
|
|
||||||
if wanted.primary {
|
|
||||||
var fh packets.FixedHeader
|
|
||||||
b := wanted.rawBytes
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader(b))
|
|
||||||
|
|
||||||
err := p.ReadFixedHeader(&fh)
|
|
||||||
if wanted.failFirst != nil {
|
|
||||||
require.Error(t, err, "Expected error reading fixedheader [i:%d] %s - %s", i, wanted.desc, Names[code])
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err, "Error reading fixedheader [i:%d] %s - %s", i, wanted.desc, Names[code])
|
|
||||||
}
|
|
||||||
|
|
||||||
pko, err := p.Read()
|
|
||||||
|
|
||||||
if wanted.expect != nil {
|
|
||||||
require.Error(t, err, "Expected error reading packet [i:%d] %s - %s", i, wanted.desc, Names[code])
|
|
||||||
if err != nil {
|
|
||||||
require.Equal(t, err, wanted.expect, "Mismatched packet error [i:%d] %s - %s", i, wanted.desc, Names[code])
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err, "Error reading packet [i:%d] %s - %s", i, wanted.desc, Names[code])
|
|
||||||
require.Equal(t, wanted.packet, pko, "Mismatched packet final [i:%d] %s - %s", i, wanted.desc, Names[code])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadFail(t *testing.T) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
|
|
||||||
var fh packets.FixedHeader
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader([]byte{
|
|
||||||
byte(packets.Publish << 4), 3, // Fixed header
|
|
||||||
0, 5, // Topic Name - LSB+MSB
|
|
||||||
'a', '/',
|
|
||||||
}))
|
|
||||||
err := p.ReadFixedHeader(&fh)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = p.Read()
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkRead(b *testing.B) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
|
|
||||||
pkb := []byte{
|
|
||||||
byte(packets.Publish << 4), 18, // Fixed header
|
|
||||||
0, 5, // Topic Name - LSB+MSB
|
|
||||||
'a', '/', 'b', '/', 'c', // Topic Name
|
|
||||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
|
|
||||||
}
|
|
||||||
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader(pkb))
|
|
||||||
var fh packets.FixedHeader
|
|
||||||
err := p.ReadFixedHeader(&fh)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var rn bytes.Reader = *bytes.NewReader(pkb)
|
|
||||||
var rc bytes.Reader
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
rc = rn
|
|
||||||
p.R.Reset(&rc)
|
|
||||||
p.R.Discard(2)
|
|
||||||
_, err := p.Read()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is a super important test. It checks whether or not subsequent packets
|
|
||||||
// mutate each other. This happens when you use a single byte buffer for decoding
|
|
||||||
// multiple packets.
|
|
||||||
func TestReadPacketNoOverwrite(t *testing.T) {
|
|
||||||
pk1 := []byte{
|
|
||||||
byte(packets.Publish << 4), 12, // Fixed header
|
|
||||||
0, 5, // Topic Name - LSB+MSB
|
|
||||||
'a', '/', 'b', '/', 'c', // Topic Name
|
|
||||||
'h', 'e', 'l', 'l', 'o', // Payload
|
|
||||||
}
|
|
||||||
|
|
||||||
pk2 := []byte{
|
|
||||||
byte(packets.Publish << 4), 14, // Fixed header
|
|
||||||
0, 5, // Topic Name - LSB+MSB
|
|
||||||
'x', '/', 'y', '/', 'z', // Topic Name
|
|
||||||
'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload
|
|
||||||
}
|
|
||||||
|
|
||||||
r, w := net.Pipe()
|
|
||||||
p := NewParser(r, newBufioReader(r), newBufioWriter(w))
|
|
||||||
go func() {
|
|
||||||
w.Write(pk1)
|
|
||||||
w.Write(pk2)
|
|
||||||
w.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var fh packets.FixedHeader
|
|
||||||
err := p.ReadFixedHeader(&fh)
|
|
||||||
require.NoError(t, err)
|
|
||||||
o1, err := p.Read()
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload)
|
|
||||||
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, pk1[9:])
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = p.ReadFixedHeader(&fh)
|
|
||||||
require.NoError(t, err)
|
|
||||||
o2, err := p.Read()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, []byte{'y', 'a', 'h', 'a', 'l', 'l', 'o'}, o2.(*packets.PublishPacket).Payload)
|
|
||||||
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload, "o1 payload was mutated")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadPacketNil(t *testing.T) {
|
|
||||||
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
var fh packets.FixedHeader
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
|
|
||||||
// Check for un-specified packet.
|
|
||||||
// Create a ping request packet with a false fixedheader type code.
|
|
||||||
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
|
|
||||||
|
|
||||||
pk.FixedHeader.Type = 99
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader([]byte{0, 0}))
|
|
||||||
|
|
||||||
err := p.ReadFixedHeader(&fh)
|
|
||||||
_, err = p.Read()
|
|
||||||
|
|
||||||
require.Error(t, err, "Expected error reading packet")
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadPacketReadOverflow(t *testing.T) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
var fh packets.FixedHeader
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
|
|
||||||
// Check for un-specified packet.
|
|
||||||
// Create a ping request packet with a false fixedheader type code.
|
|
||||||
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
|
|
||||||
|
|
||||||
pk.FixedHeader.Type = 99
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader([]byte{byte(packets.Connect << 4), 0}))
|
|
||||||
|
|
||||||
err := p.ReadFixedHeader(&fh)
|
|
||||||
|
|
||||||
p.FixedHeader.Remaining = 999999 // overflow buffer
|
|
||||||
_, err = p.Read()
|
|
||||||
|
|
||||||
require.Error(t, err, "Expected error reading packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadPacketReadAllFail(t *testing.T) {
|
|
||||||
conn := new(MockNetConn)
|
|
||||||
var fh packets.FixedHeader
|
|
||||||
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
|
|
||||||
|
|
||||||
// Check for un-specified packet.
|
|
||||||
// Create a ping request packet with a false fixedheader type code.
|
|
||||||
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
|
|
||||||
|
|
||||||
pk.FixedHeader.Type = 99
|
|
||||||
p.R = bufio.NewReader(bytes.NewReader([]byte{byte(packets.Connect << 4), 0}))
|
|
||||||
|
|
||||||
err := p.ReadFixedHeader(&fh)
|
|
||||||
|
|
||||||
p.FixedHeader.Remaining = 1 // overflow buffer
|
|
||||||
_, err = p.Read()
|
|
||||||
|
|
||||||
require.Error(t, err, "Expected error reading packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
// MockNetConn satisfies the net.Conn interface.
|
|
||||||
type MockNetConn struct {
|
|
||||||
ID string
|
|
||||||
Deadline time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads bytes from the net io.reader.
|
|
||||||
func (m *MockNetConn) Read(b []byte) (n int, err error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read writes bytes to the net io.writer.
|
|
||||||
func (m *MockNetConn) Write(b []byte) (n int, err error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the net.Conn connection.
|
|
||||||
func (m *MockNetConn) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalAddr returns the local address of the request.
|
|
||||||
func (m *MockNetConn) LocalAddr() net.Addr {
|
|
||||||
return new(MockNetAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteAddr returns the remove address of the request.
|
|
||||||
func (m *MockNetConn) RemoteAddr() net.Addr {
|
|
||||||
return new(MockNetAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDeadline sets the request deadline.
|
|
||||||
func (m *MockNetConn) SetDeadline(t time.Time) error {
|
|
||||||
m.Deadline = t
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetReadDeadline sets the read deadline.
|
|
||||||
func (m *MockNetConn) SetReadDeadline(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWriteDeadline sets the write deadline.
|
|
||||||
func (m *MockNetConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockNetAddr satisfies net.Addr interface.
|
|
||||||
type MockNetAddr struct{}
|
|
||||||
|
|
||||||
// Network returns the network protocol.
|
|
||||||
func (m *MockNetAddr) Network() string {
|
|
||||||
return "tcp"
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns the network address.
|
|
||||||
func (m *MockNetAddr) String() string {
|
|
||||||
return "127.0.0.1"
|
|
||||||
}
|
|
||||||
*/
|
|
@@ -1,67 +0,0 @@
|
|||||||
package pools
|
|
||||||
|
|
||||||
/*
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BufioReadersPool is a pool of bufio.Reader.
|
|
||||||
type BufioReadersPool struct {
|
|
||||||
pool sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBufioReadersPool returns a sync.pool of bufio.Reader.
|
|
||||||
func NewBufioReadersPool(size int) BufioReadersPool {
|
|
||||||
return BufioReadersPool{
|
|
||||||
pool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
return bufio.NewReaderSize(nil, size)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns a pooled bufio.Reader.
|
|
||||||
func (b BufioReadersPool) Get(r io.Reader) *bufio.Reader {
|
|
||||||
buf := b.pool.Get().(*bufio.Reader)
|
|
||||||
buf.Reset(r)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put puts the bufio.Reader back into the pool.
|
|
||||||
func (b BufioReadersPool) Put(x *bufio.Reader) {
|
|
||||||
x.Reset(nil)
|
|
||||||
b.pool.Put(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BufioWritersPool is a pool of bufio.Writer.
|
|
||||||
type BufioWritersPool struct {
|
|
||||||
pool sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBufioWritersPool returns a sync.pool of bufio.Writer.
|
|
||||||
func NewBufioWritersPool(size int) BufioWritersPool {
|
|
||||||
return BufioWritersPool{
|
|
||||||
pool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
return bufio.NewWriterSize(nil, size)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns a pooled bufio.Writer.
|
|
||||||
func (b BufioWritersPool) Get(r io.Writer) *bufio.Writer {
|
|
||||||
buf := b.pool.Get().(*bufio.Writer)
|
|
||||||
buf.Reset(r)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put puts the bufio.Writer back into the pool.
|
|
||||||
func (b BufioWritersPool) Put(x *bufio.Writer) {
|
|
||||||
x.Reset(nil)
|
|
||||||
b.pool.Put(x)
|
|
||||||
}
|
|
||||||
*/
|
|
@@ -1,115 +0,0 @@
|
|||||||
package pools
|
|
||||||
|
|
||||||
/*
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewBufioReadersPool(t *testing.T) {
|
|
||||||
bpool := NewBufioReadersPool(16)
|
|
||||||
require.NotNil(t, bpool.pool)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkNewBufioReadersPool(b *testing.B) {
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
NewBufioReadersPool(16)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewBufioReadersGetPut(t *testing.T) {
|
|
||||||
bpool := NewBufioReadersPool(16)
|
|
||||||
r := bufio.NewReader(strings.NewReader("mochi"))
|
|
||||||
out := make([]byte, 5)
|
|
||||||
buf := bpool.Get(r)
|
|
||||||
n, err := buf.Read(out)
|
|
||||||
require.Equal(t, 5, n)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, out)
|
|
||||||
|
|
||||||
bpool.Put(buf)
|
|
||||||
require.Panics(t, func() {
|
|
||||||
buf.ReadByte()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBufioReadersGet(b *testing.B) {
|
|
||||||
bpool := NewBufioReadersPool(16)
|
|
||||||
r := bufio.NewReader(strings.NewReader("mochi"))
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
bpool.Get(r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBufioReadersPut(b *testing.B) {
|
|
||||||
bpool := NewBufioReadersPool(16)
|
|
||||||
r := bufio.NewReader(strings.NewReader("mochi"))
|
|
||||||
buf := bpool.Get(r)
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
bpool.Put(buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewBufioWritersPool(t *testing.T) {
|
|
||||||
bpool := NewBufioWritersPool(16)
|
|
||||||
require.NotNil(t, bpool.pool)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkNewBufioWritersPool(b *testing.B) {
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
NewBufioWritersPool(16)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewBufioWritersGetPut(t *testing.T) {
|
|
||||||
payload := []byte{'m', 'o', 'c', 'h', 'i'}
|
|
||||||
|
|
||||||
bpool := NewBufioWritersPool(16)
|
|
||||||
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
bw := bufio.NewWriter(buf)
|
|
||||||
require.NotNil(t, bw)
|
|
||||||
w := bpool.Get(bw)
|
|
||||||
require.NotNil(t, w)
|
|
||||||
|
|
||||||
n, err := w.Write(payload)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 5, n)
|
|
||||||
|
|
||||||
w.Flush()
|
|
||||||
bw.Flush()
|
|
||||||
require.Equal(t, payload, buf.Bytes())
|
|
||||||
|
|
||||||
bpool.Put(w)
|
|
||||||
_, err = w.Write(payload)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Panics(t, func() {
|
|
||||||
w.Flush()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBufioWritersGet(b *testing.B) {
|
|
||||||
bpool := NewBufioWritersPool(16)
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
bw := bufio.NewWriter(buf)
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
bpool.Get(bw)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBufioWritersPut(b *testing.B) {
|
|
||||||
bpool := NewBufioWritersPool(16)
|
|
||||||
w := bpool.Get(bufio.NewWriter(new(bytes.Buffer)))
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
bpool.Put(w)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
@@ -1,33 +0,0 @@
|
|||||||
package pools
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BytesBuffersPool is a pool of bytes.Buffer.
|
|
||||||
type BytesBuffersPool struct {
|
|
||||||
pool sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBytesBuffersPool returns a sync.pool of bytes.Buffer.
|
|
||||||
func NewBytesBuffersPool() BytesBuffersPool {
|
|
||||||
return BytesBuffersPool{
|
|
||||||
pool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
return new(bytes.Buffer)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns a pooled bytes.Buffer.
|
|
||||||
func (b BytesBuffersPool) Get() *bytes.Buffer {
|
|
||||||
return b.pool.Get().(*bytes.Buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put puts the bytes.Buffer back into the pool.
|
|
||||||
func (b BytesBuffersPool) Put(x *bytes.Buffer) {
|
|
||||||
x.Reset()
|
|
||||||
b.pool.Put(x)
|
|
||||||
}
|
|
@@ -1,52 +0,0 @@
|
|||||||
package pools
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewBytesBuffersPool(t *testing.T) {
|
|
||||||
bpool := NewBytesBuffersPool()
|
|
||||||
require.NotNil(t, bpool.pool)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkNewBytesBuffersPool(b *testing.B) {
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
NewBytesBuffersPool()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewBytesBuffersPoolGet(t *testing.T) {
|
|
||||||
bpool := NewBytesBuffersPool()
|
|
||||||
buf := bpool.Get()
|
|
||||||
|
|
||||||
buf.WriteString("mochi")
|
|
||||||
require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, buf.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBytesBuffersPoolGet(b *testing.B) {
|
|
||||||
bpool := NewBytesBuffersPool()
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
bpool.Get()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewBytesBuffersPoolPut(t *testing.T) {
|
|
||||||
bpool := NewBytesBuffersPool()
|
|
||||||
buf := bpool.Get()
|
|
||||||
|
|
||||||
buf.WriteString("mochi")
|
|
||||||
require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, buf.Bytes())
|
|
||||||
|
|
||||||
bpool.Put(buf)
|
|
||||||
require.Equal(t, 0, buf.Len())
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBytesBuffersPoolPut(b *testing.B) {
|
|
||||||
bpool := NewBytesBuffersPool()
|
|
||||||
buf := bpool.Get()
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
bpool.Put(buf)
|
|
||||||
}
|
|
||||||
}
|
|
60
processor.go
60
processor.go
@@ -2,12 +2,13 @@ package mqtt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mochi-co/mqtt/circ"
|
||||||
"github.com/mochi-co/mqtt/packets"
|
"github.com/mochi-co/mqtt/packets"
|
||||||
"github.com/mochi-co/mqtt/processors/circ"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Processor reads and writes bytes to a network connection.
|
// Processor reads and writes bytes to a network connection.
|
||||||
@@ -22,6 +23,12 @@ type Processor struct {
|
|||||||
// W is a writer for writing outgoing bytes.
|
// W is a writer for writing outgoing bytes.
|
||||||
W *circ.Writer
|
W *circ.Writer
|
||||||
|
|
||||||
|
// started tracks the goroutines which have been started.
|
||||||
|
started sync.WaitGroup
|
||||||
|
|
||||||
|
// ended tracks the goroutines which have ended.
|
||||||
|
ended sync.WaitGroup
|
||||||
|
|
||||||
// FixedHeader is the FixedHeader from the last read packet.
|
// FixedHeader is the FixedHeader from the last read packet.
|
||||||
FixedHeader packets.FixedHeader
|
FixedHeader packets.FixedHeader
|
||||||
}
|
}
|
||||||
@@ -33,6 +40,51 @@ func NewProcessor(c net.Conn, r *circ.Reader, w *circ.Writer) *Processor {
|
|||||||
R: r,
|
R: r,
|
||||||
W: w,
|
W: w,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start spins up the reader and writer goroutines.
|
||||||
|
func (p *Processor) Start() {
|
||||||
|
go func() {
|
||||||
|
defer p.ended.Done()
|
||||||
|
fmt.Println("starting readFrom", p.Conn)
|
||||||
|
p.started.Done()
|
||||||
|
n, err := p.R.ReadFrom(p.Conn)
|
||||||
|
if err != nil {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
fmt.Println(">>> finished ReadFrom", n, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer p.ended.Done()
|
||||||
|
fmt.Println("starting writeTo", p.Conn)
|
||||||
|
p.started.Done()
|
||||||
|
n, err := p.W.WriteTo(p.Conn)
|
||||||
|
if err != nil {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(">>> finished WriteTo", n, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
p.started.Add(2)
|
||||||
|
p.ended.Add(2)
|
||||||
|
p.started.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the processor goroutines.
|
||||||
|
func (p *Processor) Stop() {
|
||||||
|
fmt.Println("processor stop")
|
||||||
|
|
||||||
|
p.W.Close()
|
||||||
|
if p.Conn != nil {
|
||||||
|
p.Conn.Close()
|
||||||
|
}
|
||||||
|
p.R.Close()
|
||||||
|
|
||||||
|
p.ended.Wait()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
// RefreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||||
@@ -99,6 +151,8 @@ func (p *Processor) ReadFixedHeader(fh *packets.FixedHeader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("FIXEDHEADER READ", *fh)
|
||||||
|
|
||||||
// Set the fixed header in the parser.
|
// Set the fixed header in the parser.
|
||||||
p.FixedHeader = *fh
|
p.FixedHeader = *fh
|
||||||
|
|
||||||
@@ -139,7 +193,7 @@ func (p *Processor) Read() (pk packets.Packet, err error) {
|
|||||||
case packets.Disconnect:
|
case packets.Disconnect:
|
||||||
pk = &packets.DisconnectPacket{FixedHeader: p.FixedHeader}
|
pk = &packets.DisconnectPacket{FixedHeader: p.FixedHeader}
|
||||||
default:
|
default:
|
||||||
return pk, errors.New("No valid packet available; " + string(p.FixedHeader.Type))
|
return pk, fmt.Errorf("No valid packet available; %v", p.FixedHeader.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
bt, err := p.R.Read(int64(p.FixedHeader.Remaining))
|
bt, err := p.R.Read(int64(p.FixedHeader.Remaining))
|
||||||
|
@@ -7,20 +7,20 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/mochi-co/mqtt/circ"
|
||||||
"github.com/mochi-co/mqtt/packets"
|
"github.com/mochi-co/mqtt/packets"
|
||||||
"github.com/mochi-co/mqtt/processors/circ"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewProcessor(t *testing.T) {
|
func TestNewProcessor(t *testing.T) {
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
|
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
|
||||||
require.NotNil(t, p.R)
|
require.NotNil(t, p.R)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkNewProcessor(b *testing.B) {
|
func BenchmarkNewProcessor(b *testing.B) {
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
r := circ.NewReader(16)
|
r := circ.NewReader(16, 4)
|
||||||
w := circ.NewWriter(16)
|
w := circ.NewWriter(16, 4)
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
NewProcessor(conn, r, w)
|
NewProcessor(conn, r, w)
|
||||||
@@ -29,7 +29,7 @@ func BenchmarkNewProcessor(b *testing.B) {
|
|||||||
|
|
||||||
func TestProcessorRefreshDeadline(t *testing.T) {
|
func TestProcessorRefreshDeadline(t *testing.T) {
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
|
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
|
||||||
|
|
||||||
dl := p.Conn.(*MockNetConn).Deadline
|
dl := p.Conn.(*MockNetConn).Deadline
|
||||||
p.RefreshDeadline(10)
|
p.RefreshDeadline(10)
|
||||||
@@ -39,21 +39,17 @@ func TestProcessorRefreshDeadline(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkProcessorRefreshDeadline(b *testing.B) {
|
func BenchmarkProcessorRefreshDeadline(b *testing.B) {
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
|
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
p.RefreshDeadline(10)
|
p.RefreshDeadline(10)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessorStart(t *testing.T) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProcessorReadFixedHeader(t *testing.T) {
|
func TestProcessorReadFixedHeader(t *testing.T) {
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
|
|
||||||
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
|
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
|
||||||
|
|
||||||
// Test null data.
|
// Test null data.
|
||||||
fh := new(packets.FixedHeader)
|
fh := new(packets.FixedHeader)
|
||||||
@@ -82,7 +78,7 @@ func TestProcessorRead(t *testing.T) {
|
|||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
|
|
||||||
var fh packets.FixedHeader
|
var fh packets.FixedHeader
|
||||||
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
|
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
|
||||||
p.R.Set([]byte{
|
p.R.Set([]byte{
|
||||||
byte(packets.Publish << 4), 18, // Fixed header
|
byte(packets.Publish << 4), 18, // Fixed header
|
||||||
0, 5, // Topic Name - LSB+MSB
|
0, 5, // Topic Name - LSB+MSB
|
||||||
@@ -108,7 +104,7 @@ func TestProcessorRead(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcessorReadFail(t *testing.T) {
|
func TestProcessorReadFail(t *testing.T) {
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
|
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
|
||||||
p.R.Set([]byte{
|
p.R.Set([]byte{
|
||||||
byte(packets.Publish << 4), 3, // Fixed header
|
byte(packets.Publish << 4), 3, // Fixed header
|
||||||
0, 5, // Topic Name - LSB+MSB
|
0, 5, // Topic Name - LSB+MSB
|
||||||
@@ -143,7 +139,7 @@ func TestProcessorReadPacketNoOverwrite(t *testing.T) {
|
|||||||
'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload
|
'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload
|
||||||
}
|
}
|
||||||
|
|
||||||
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
|
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
|
||||||
p.R.Set(pk1, 0, len(pk1))
|
p.R.Set(pk1, 0, len(pk1))
|
||||||
p.R.SetPos(0, int64(len(pk1)))
|
p.R.SetPos(0, int64(len(pk1)))
|
||||||
var fh packets.FixedHeader
|
var fh packets.FixedHeader
|
||||||
@@ -168,7 +164,7 @@ func TestProcessorReadPacketNoOverwrite(t *testing.T) {
|
|||||||
func TestProcessorReadPacketNil(t *testing.T) {
|
func TestProcessorReadPacketNil(t *testing.T) {
|
||||||
|
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
|
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
|
||||||
var fh packets.FixedHeader
|
var fh packets.FixedHeader
|
||||||
|
|
||||||
// Check for un-specified packet.
|
// Check for un-specified packet.
|
||||||
@@ -188,7 +184,7 @@ func TestProcessorReadPacketNil(t *testing.T) {
|
|||||||
|
|
||||||
func TestProcessorReadPacketReadOverflow(t *testing.T) {
|
func TestProcessorReadPacketReadOverflow(t *testing.T) {
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
|
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
|
||||||
var fh packets.FixedHeader
|
var fh packets.FixedHeader
|
||||||
|
|
||||||
// Check for un-specified packet.
|
// Check for un-specified packet.
|
||||||
|
Reference in New Issue
Block a user