Refactor, fixes to Circ Peek, start processor/parser

This commit is contained in:
Mochi
2019-10-23 21:05:37 +01:00
parent e0a00feb6b
commit 7913e93fcf
13 changed files with 244 additions and 417 deletions

18
mqtt.go
View File

@@ -12,6 +12,7 @@ import (
"github.com/mochi-co/mqtt/auth"
"github.com/mochi-co/mqtt/listeners"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/processors/circ"
"github.com/mochi-co/mqtt/topics"
"github.com/mochi-co/mqtt/topics/trie"
)
@@ -90,7 +91,22 @@ func (s *Server) AddListener(listener listeners.Listener, config *listeners.Conf
// Serve begins the event loops for establishing client connections on all
// attached listeners.
func (s *Server) Serve() error {
s.listeners.ServeAll(s.EstablishConnection)
s.listeners.ServeAll(s.EstablishConnection2)
return nil
}
// EstablishConnection establishes a new client connection with the broker.
func (s *Server) EstablishConnection2(lid string, c net.Conn, ac auth.Controller) error {
p := NewProcessor(c, circ.NewReader(128), circ.NewWriter(128))
// Pull the header from the first packet and check for a CONNECT message.
fh := new(packets.FixedHeader)
err := p.ReadFixedHeader(fh)
if err != nil {
return ErrReadConnectFixedHeader
}
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"io"
"log"
"net"
"sync"
"time"
@@ -67,6 +68,8 @@ func (p *Parser) ReadFixedHeader(fh *packets.FixedHeader) error {
return err
}
log.Println("Peeked", peeked)
// Unpack message type and flags from byte 1.
err = fh.Decode(peeked[0])
if err != nil {

View File

@@ -194,7 +194,7 @@ var fixedHeaderExpected = []fixedHeaderTable{
}
*/
func TestReadFixedHeader(t *testing.T) {
func TestParserReadFixedHeader(t *testing.T) {
conn := new(MockNetConn)
// Test null data.

108
processor.go Normal file
View File

@@ -0,0 +1,108 @@
package mqtt
import (
"encoding/binary"
"log"
"net"
"time"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/processors/circ"
)
// Processor reads and writes bytes to a network connection.
type Processor struct {
// Conn is the net.Conn used to establish the connection.
Conn net.Conn
// R is a reader for reading incoming bytes.
R *circ.Reader
// W is a writer for writing outgoing bytes.
W *circ.Writer
// FixedHeader is the FixedHeader from the last read packet.
FixedHeader packets.FixedHeader
}
// NewProcessor returns a new instance of Processor.
func NewProcessor(c net.Conn, r *circ.Reader, w *circ.Writer) *Processor {
return &Processor{
Conn: c,
R: r,
W: w,
}
}
// RefreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (p *Processor) RefreshDeadline(keepalive uint16) {
if p.Conn != nil {
expiry := time.Duration(keepalive+(keepalive/2)) * time.Second
p.Conn.SetDeadline(time.Now().Add(expiry))
}
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (p *Processor) ReadFixedHeader(fh *packets.FixedHeader) error {
// Peek the maximum message type and flags, and length.
peeked, err := p.R.Peek(1)
if err != nil {
return err
}
log.Println("Peeked", peeked)
// Unpack message type and flags from byte 1.
err = fh.Decode(peeked[0])
if err != nil {
// @SPEC [MQTT-2.2.2-2]
// If invalid flags are received, the receiver MUST close the Network Connection.
return packets.ErrInvalidFlags
}
log.Printf("FH %+v\n", fh)
// The remaining length value can be up to 5 bytes. Peek through each byte
// looking for continue values, and if found increase the peek. Otherwise
// decode the bytes that were legit.
//p.fhBuffer = p.fhBuffer[:0]
buf := make([]byte, 0, 6)
i := 1
//var b int64 = 2
for b := int64(2); b < 6; b++ {
peeked, err = p.R.Peek(b)
if err != nil {
return err
}
// Add the byte to the length bytes slice.
buf = append(buf, peeked[i])
// If it's not a continuation flag, end here.
if peeked[i] < 128 {
break
}
// If i has reached 4 without a length terminator, throw a protocol violation.
i++
if i == 4 {
return packets.ErrOversizedLengthIndicator
}
}
// Calculate and store the remaining length.
rem, _ := binary.Uvarint(buf)
fh.Remaining = int(rem)
// Discard the number of used length bytes + first byte.
//p.R.Discard(b)
// Set the fixed header in the parser.
p.FixedHeader = *fh
return nil
}

44
processor_test.go Normal file
View File

@@ -0,0 +1,44 @@
package mqtt
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/processors/circ"
)
func TestNewProcessor(t *testing.T) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
require.NotNil(t, p.R)
}
func TestProcessorReadFixedHeader(t *testing.T) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
// Test null data.
fh := new(packets.FixedHeader)
//p.R.SetPos(0, 1)
fmt.Printf("A %+v\n", p.R)
err := p.ReadFixedHeader(fh)
require.Error(t, err)
// Test insufficient peeking.
fh = new(packets.FixedHeader)
p.R.Set([]byte{packets.Connect << 4}, 0, 1)
p.R.SetPos(0, 1)
fmt.Printf("B %+v\n", p.R)
err = p.ReadFixedHeader(fh)
require.Error(t, err)
fh = new(packets.FixedHeader)
p.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
p.R.SetPos(0, 2)
err = p.ReadFixedHeader(fh)
require.NoError(t, err)
}

View File

@@ -51,6 +51,22 @@ func newBuffer(size int64) buffer {
}
}
// Set writes bytes to a byte buffer. This method should only be used for testing
// and will panic if out of range.
func (b *buffer) Set(p []byte, start, end int) {
o := 0
for i := start; i < end; i++ {
b.buf[i] = p[o]
o++
}
}
// SetPos sets the head and tail of the buffer. This method should only be
// used for testing.
func (b *buffer) SetPos(tail, head int64) {
b.tail, b.head = tail, head
}
// awaitCapacity will hold until there are n bytes free in the buffer, blocking
// until there are at least n bytes to write without overrunning the tail.
func (b *buffer) awaitCapacity(n int64) (head int64, err error) {

View File

@@ -23,6 +23,17 @@ func TestNewBuffer(t *testing.T) {
require.Equal(t, size, buf.size)
}
func TestSet(t *testing.T) {
buf := newBuffer(8)
require.Equal(t, make([]byte, 4), buf.buf[0:4])
p := []byte{'1', '2', '3', '4'}
buf.Set(p, 0, 4)
require.Equal(t, p, buf.buf[0:4])
buf.Set(p, 2, 6)
require.Equal(t, p, buf.buf[2:6])
}
func TestAwaitCapacity(t *testing.T) {
tests := []struct {
tail int64

View File

@@ -70,6 +70,40 @@ func (b *Reader) Peek(n int64) ([]byte, error) {
tail := atomic.LoadInt64(&b.tail)
head := atomic.LoadInt64(&b.head)
// Wait until there's at least 1 byte of data.
/*
b.wcond.L.Lock()
for ; head == tail; head = atomic.LoadInt64(&b.head) {
if atomic.LoadInt64(&b.done) == 1 {
return nil, io.EOF
}
b.wcond.Wait()
}
b.wcond.L.Unlock()
*/
// Figure out if we can get all n bytes.
if head == tail || head > tail && tail+n > head || head < tail && b.size-tail+head < n {
return nil, ErrInsufficientBytes
}
// If we've wrapped, get the bytes from the next and start.
if head < tail {
b.tmp = b.buf[tail:b.size]
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
return b.tmp, nil
} else {
return b.buf[tail : tail+n], nil
}
return nil, nil
}
// PeekWait waits until their are n bytes available to return.
func (b *Reader) PeekWait(n int64) ([]byte, error) {
tail := atomic.LoadInt64(&b.tail)
head := atomic.LoadInt64(&b.head)
// Wait until there's at least 1 byte of data.
b.wcond.L.Lock()
for ; head == tail; head = atomic.LoadInt64(&b.head) {

View File

@@ -54,7 +54,17 @@ DONE:
return
}
// Write writes the buffer to the buffer p.
func (b *Writer) Write(p []byte) (nn int, err error) {
if atomic.LoadInt64(&b.done) == 1 {
return 0, io.EOF
}
// Wait until there's
_, err = b.awaitCapacity(int64(len(p)))
if err != nil {
return
}
return
}

View File

@@ -1,199 +0,0 @@
package ring
import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
)
var (
blockSize int64 = 4
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
)
// _______________________________________________________________________
// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
// -----------------------------------------------------------------------
// Buffer is a ring-style byte buffer.
type Buffer struct {
// size is the size of the buffer.
size int64
// buffer is the circular byte buffer.
buffer []byte
// tmp is a temporary buffer.
tmp []byte
// head is the current position in the sequence.
head int64
// tail is the committed position in the sequence, I.E. where we
// have successfully consumed and processed to.
tail int64
// rcond is the sync condition for the reader.
rcond *sync.Cond
// done indicates that the buffer is closed.
done int64
// wcond is the sync condition for the writer.
wcond *sync.Cond
}
// NewBuffer returns a pointer to a new instance of Buffer.
func NewBuffer(size int64) *Buffer {
return &Buffer{
size: size,
buffer: make([]byte, size),
tmp: make([]byte, size),
rcond: sync.NewCond(new(sync.Mutex)),
wcond: sync.NewCond(new(sync.Mutex)),
}
}
// awaitCapacity will hold until there are n bytes free in the buffer.
// awaitCapacity will block until there are at least n bytes for the head to
// write into without overrunning the tail.
func (b *Buffer) awaitCapacity(n int64) (int64, error) {
tail := atomic.LoadInt64(&b.tail)
head := atomic.LoadInt64(&b.head)
next := head + n
wrapped := next - b.size
b.rcond.L.Lock()
for ; wrapped > tail || (tail > head && next > tail && wrapped < 0); tail = atomic.LoadInt64(&b.tail) {
//fmt.Println("\t", wrapped, ">", tail, wrapped > tail, "||", tail, ">", head, "&&", next, ">", tail, "&&", wrapped, "<", 0, (tail > head && next > tail && wrapped < 0))
b.rcond.Wait()
if atomic.LoadInt64(&b.done) == 1 {
fmt.Println("ending")
return 0, io.EOF
}
}
b.rcond.L.Unlock()
return head, nil
}
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Buffer) ReadFrom(r io.Reader) (total int64, err error) {
DONE:
for {
if atomic.LoadInt64(&b.done) == 1 {
err = io.EOF
break DONE
//return 0, io.EOF
}
// Wait until there's enough capacity in the buffer.
st, err := b.awaitCapacity(blockSize)
if err != nil {
break DONE
//return 0, err
}
// Find the start and end indexes within the buffer.
start := st & (b.size - 1)
end := start + blockSize
// If we've overrun, just fill up to the end and then collect
// the rest on the next pass.
if end > b.size {
end = b.size
}
// Read into the buffer between the start and end indexes only.
n, err := r.Read(b.buffer[start:end])
total += int64(n) // incr total bytes read.
if err != nil {
break DONE
//return int64(n), err
}
// Move the head forward.
//fmt.Println(start, end, b.buffer[start:end])
// fmt.Println(">>", b.head, start+int64(n), end)
atomic.StoreInt64(&b.head, end) //start+int64(n))
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
}
return
}
// WriteTo writes the contents of the buffer to an io.Writer.
func (b *Buffer) WriteTo(w io.Writer) (total int64, err error) {
var p []byte
var n int
DONE:
for {
if atomic.LoadInt64(&b.done) == 1 {
err = io.EOF
break DONE
}
// Peek until there's bytes to write.
p, err = b.Peek(blockSize)
if err != nil {
break DONE
}
n, err = w.Write(p)
total += int64(n)
if err != nil {
break DONE
}
// Move the tail forward the bytes written.
end := (atomic.LoadInt64(&b.tail) + int64(n)) % b.size
atomic.StoreInt64(&b.tail, end)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
}
return
}
// Peek returns the next n bytes without advancing the reader.
func (b *Buffer) Peek(n int64) ([]byte, error) {
tail := atomic.LoadInt64(&b.tail)
head := atomic.LoadInt64(&b.head)
// Wait until there's at least 1 byte of data.
b.wcond.L.Lock()
for ; head == tail; head = atomic.LoadInt64(&b.head) {
if atomic.LoadInt64(&b.done) == 1 {
return nil, io.EOF
}
b.wcond.Wait()
}
b.wcond.L.Unlock()
// Figure out if we can get all n bytes.
start := tail
end := tail + n
if head > tail && tail+n > head || head < tail && b.size-tail+head < n {
return nil, ErrInsufficientBytes
}
// If we've wrapped, get the bytes from the end and start.
if head < tail {
b.tmp = b.buffer[start:b.size]
b.tmp = append(b.tmp, b.buffer[:(end%b.size)]...)
return b.tmp, nil
} else {
return b.buffer[start:end], nil
}
return nil, nil
}

View File

@@ -1,214 +0,0 @@
package ring
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func init() {
blockSize = 4
}
func TestNewBuffer(t *testing.T) {
var size int64 = 256
buf := NewBuffer(size)
require.NotNil(t, buf.buffer)
require.Equal(t, size, int64(len(buf.buffer)))
require.Equal(t, size, buf.size)
}
func TestAwaitCapacity(t *testing.T) {
tt := []struct {
tail int64
head int64
need int64
await int
desc string
}{
{0, 0, 4, 0, "OK 4, 0"},
{6, 6, 4, 0, "OK 6, 10"},
{12, 16, 4, 0, "OK 16, 4"},
{16, 12, 4, 0, "OK 16 16"},
{3, 6, 4, 0, "OK 3, 10"},
{12, 0, 4, 0, "OK 16, 0"},
{1, 16, 4, 4, "next is more than tail, wait for tail incr"},
{7, 5, 4, 2, "tail is great than head, wrapped and caught up with tail, wait for tail incr"},
}
buf := NewBuffer(16)
for i, check := range tt {
buf.tail, buf.head = check.tail, check.head
o := make(chan []interface{})
var start int64 = -1
var err error
go func() {
start, err = buf.awaitCapacity(4)
o <- []interface{}{start, err}
}()
time.Sleep(time.Millisecond) // atomic updates are super fast so wait a bit
for i := 0; i < check.await; i++ {
atomic.AddInt64(&buf.tail, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
}
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
if start == -1 {
atomic.StoreInt64(&buf.done, 1)
buf.rcond.L.Lock()
buf.rcond.Broadcast()
buf.rcond.L.Unlock()
}
done := <-o
require.Equal(t, check.head, done[0].(int64), "Head-Start mismatch [i:%d] %s", i, check.desc)
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, check.desc)
}
}
func TestReadFrom(t *testing.T) {
buf := NewBuffer(16)
b4 := bytes.Repeat([]byte{'-'}, 4)
br := bytes.NewReader(b4)
_, err := buf.ReadFrom(br)
require.NoError(t, err)
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buffer[:4])
require.Equal(t, int64(4), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.Equal(t, int64(8), buf.head)
br.Reset(b4)
_, err = buf.ReadFrom(br)
require.Equal(t, int64(12), buf.head)
br.Reset([]byte{'-', '-', '-', '-', '/', '/', '/', '/'})
o := make(chan error)
go func() {
_, err := buf.ReadFrom(br)
o <- err
}()
atomic.StoreInt64(&buf.tail, 4)
<-o
require.Equal(t, int64(4), buf.head)
require.Equal(t, []byte{'/', '/', '/', '/', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'}, buf.buffer)
}
func TestPeek(t *testing.T) {
buf := NewBuffer(16)
buf.buffer = []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
tests := []struct {
tail int64
head int64
want int
bytes []byte
err error
desc string
}{
{tail: 0, head: 4, want: 4, bytes: []byte{'a', 'b', 'c', 'd'}, err: nil, desc: "0,4: OK"},
{tail: 3, head: 15, want: 8, bytes: []byte{'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k'}, err: nil, desc: "0,4: OK"},
{tail: 14, head: 5, want: 6, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd'}, err: nil, desc: "14,5 OK"},
{tail: 14, head: 1, want: 6, bytes: nil, err: ErrInsufficientBytes, desc: "14,1 insufficient bytes"},
}
for i, tt := range tests {
buf.tail, buf.head = tt.tail, tt.head
o := make(chan []interface{})
go func() {
bs, err := buf.Peek(int64(tt.want))
o <- []interface{}{bs, err}
}()
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.head, buf.head+int64(tt.want))
time.Sleep(time.Millisecond * 10)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
done := <-o
if tt.err != nil {
require.Error(t, done[1].(error), "Expected Error [i:%d] %s", i, tt.desc)
} else {
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
}
require.Equal(t, tt.bytes, done[0].([]byte), "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
}
fmt.Println("done")
}
func TestWriteTo(t *testing.T) {
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
tests := []struct {
tail int64
next int64
head int64
err error
bytes []byte
total int64
desc string
}{
{tail: 0, next: 4, head: 4, err: io.EOF, bytes: []byte{'a', 'b', 'c', 'd'}, total: 4, desc: "0, 4 OK"},
{tail: 14, next: 2, head: 2, err: io.EOF, bytes: []byte{'o', 'p', 'a', 'b'}, total: 4, desc: "14, 2 wrap"},
{tail: 0, next: 0, head: 2, err: ErrInsufficientBytes, bytes: []byte{}, total: 0, desc: "0, 2 insufficient"},
{tail: 14, next: 2, head: 3, err: ErrInsufficientBytes, bytes: []byte{'o', 'p', 'a', 'b'}, total: 4, desc: "0, 3 OK > insufficient wrap"},
}
for i, tt := range tests {
buf := NewBuffer(16)
buf.buffer = bb
buf.tail, buf.head = tt.tail, tt.head
r, w := net.Pipe()
go func() {
time.Sleep(time.Millisecond)
atomic.StoreInt64(&buf.done, 1)
buf.wcond.L.Lock()
buf.wcond.Broadcast()
buf.wcond.L.Unlock()
w.Close()
}()
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(r)
if err != nil {
panic(err)
}
recv <- buf
}()
total, err := buf.WriteTo(w)
require.Equal(t, tt.next, buf.tail, "Tail placement mismatched [i:%d] %s", i, tt.desc)
require.Equal(t, tt.total, total, "Total bytes written mismatch [i:%d] %s", i, tt.desc)
if tt.err != nil {
require.Error(t, err, "Expected Error [i:%d] %s", i, tt.desc)
} else {
require.NoError(t, err, "Unexpected Error [i:%d] %s", i, tt.desc)
}
require.Equal(t, tt.bytes, <-recv, "Written bytes mismatch [i:%d] %s", i, tt.desc)
r.Close()
}
}

View File

@@ -1 +0,0 @@
package ring

View File

@@ -1 +0,0 @@
package ring