Debug func, better closers

This commit is contained in:
Mochi
2019-10-27 11:00:50 +00:00
parent f3310b0522
commit d8bba41da4
21 changed files with 441 additions and 1137 deletions

View File

@@ -1,25 +1,28 @@
package circ
import (
"fmt"
"io"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/debug"
)
var (
// blockSize is the default size of bytes per R/W block.
blockSize int64 = 8192
// bufferSize is the default size of the buffer.
bufferSize int64 = 1024 * 256
// blockSize is the default size of bytes per R/W block.
blockSize int64 = 2048
)
// buffer contains core values and methods to be included in a reader or writer.
type buffer struct {
sync.RWMutex
// id is the identifier of the buffer. This is used in debug output.
id string
// size is the size of the buffer.
size int64
@@ -73,11 +76,32 @@ func newBuffer(size, block int64) buffer {
}
}
// Close notifies the buffer event loops to stop.
func (b *buffer) Close() {
atomic.StoreInt64(&b.done, 1)
debug.Println(b.id, "[B] STORE done=1 buffer closing")
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
debug.Println(b.id, "[B] DONE REBROADCASTED")
}
// Get will return the tail and head positions of the buffer.
func (b *buffer) GetPos() (int64, int64) {
return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
}
// SetPos sets the head and tail of the buffer. This method should only be
// used for testing.
func (b *buffer) SetPos(tail, head int64) {
b.tail, b.head = tail, head
}
// Set writes bytes to a byte buffer. This method should only be used for testing
// and will panic if out of range.
func (b *buffer) Set(p []byte, start, end int) {
@@ -88,12 +112,6 @@ func (b *buffer) Set(p []byte, start, end int) {
}
}
// SetPos sets the head and tail of the buffer. This method should only be
// used for testing.
func (b *buffer) SetPos(tail, head int64) {
b.tail, b.head = tail, head
}
// awaitCapacity blocks until there are at least n bytes free in the buffer,
// i.e. until n bytes can be written without overrunning the tail.
func (b *buffer) awaitCapacity(n int64) (head int64, err error) {
@@ -102,16 +120,21 @@ func (b *buffer) awaitCapacity(n int64) (head int64, err error) {
wrapped := next - b.size
tail := atomic.LoadInt64(&b.tail)
debug.Println(b.id, "[B] awaiting capacity (n)", n)
b.rcond.L.Lock()
for ; wrapped > tail || (tail > head && next > tail && wrapped < 0); tail = atomic.LoadInt64(&b.tail) {
debug.Println(b.id, "[B] iter no capacity")
//fmt.Println("\t", wrapped, ">", tail, wrapped > tail, "||", tail, ">", head, "&&", next, ">", tail, "&&", wrapped, "<", 0, (tail > head && next > tail && wrapped < 0))
if atomic.LoadInt64(&b.done) == 1 {
fmt.Println("ending")
return 0, io.EOF
}
debug.Println(b.id, "[B] iter no capacity waiting")
b.rcond.Wait()
}
b.rcond.L.Unlock()
debug.Println(b.id, "[B] capacity unlocked (tail)", tail)
return
}
@@ -122,15 +145,20 @@ func (b *buffer) awaitFilled(n int64) (tail int64, err error) {
head := atomic.LoadInt64(&b.head)
tail = atomic.LoadInt64(&b.tail)
debug.Println(b.id, "[B] awaiting filled (tail, head)", tail, head)
b.wcond.L.Lock()
for ; head > tail && tail+n > head || head < tail && b.size-tail+head < n; head = atomic.LoadInt64(&b.head) {
debug.Println(b.id, "[B] iter no fill")
if atomic.LoadInt64(&b.done) == 1 {
fmt.Println("ending")
return 0, io.EOF
}
debug.Println(b.id, "[B] iter no fill waiting")
b.wcond.Wait()
}
b.wcond.L.Unlock()
debug.Println(b.id, "[B] filled (head)", head)
return
}
@@ -144,16 +172,20 @@ func (b *buffer) awaitFilled(n int64) (tail int64, err error) {
// CommitTail moves the tail position of the buffer forward by n bytes, and will
// wait until at least n bytes have been filled and can be discarded.
func (b *buffer) CommitTail(n int64) error {
debug.Println(b.id, "[B] commit/discard (n)", n)
_, err := b.awaitFilled(n)
if err != nil {
return err
}
debug.Println(b.id, "[B] commit/discard unlocked")
tail := atomic.LoadInt64(&b.tail)
if tail+n < b.size {
atomic.StoreInt64(&b.tail, tail+n)
debug.Println(b.id, "[B] STORE TAIL", tail+n)
} else {
atomic.StoreInt64(&b.tail, (tail+n)%b.size)
debug.Println(b.id, "[B] STORE TAIL (wrapped)", (tail+n)%b.size)
}
b.rcond.L.Lock()
@@ -162,3 +194,14 @@ func (b *buffer) CommitTail(n int64) error {
return nil
}
// capDelta returns the difference between the head and tail.
func (b *buffer) capDelta(tail, head int64) int64 {
if head < tail {
debug.Println(b.id, "[B] capdelta (tail, head, v)", tail, head, b.size-tail+head)
return b.size - tail + head
}
debug.Println(b.id, "[B] capdelta (tail, head, v)", tail, head, head-tail)
return head - tail
}
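The wrap-around arithmetic used by awaitCapacity and capDelta is easier to follow in isolation. The following standalone sketch reproduces the capDelta calculation with illustrative sizes and positions (none of these values come from the commit itself):

package main

import "fmt"

// capDelta mirrors the buffer method above: it returns the number of
// committed (readable) bytes between tail and head in a ring of the given
// size, accounting for wrap-around.
func capDelta(size, tail, head int64) int64 {
	if head < tail {
		return size - tail + head
	}
	return head - tail
}

func main() {
	const size = 16

	// Head ahead of tail, no wrap: 10 - 2 = 8 bytes committed.
	fmt.Println(capDelta(size, 2, 10))

	// Head has wrapped past the end: 16 - 12 + 4 = 8 bytes committed.
	fmt.Println(capDelta(size, 12, 4))
}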

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"io"
"sync/atomic"
"github.com/mochi-co/mqtt/debug"
)
var (
@@ -18,25 +20,31 @@ type Reader struct {
// NewReader returns a pointer to a new Circular Reader.
func NewReader(size, block int64) *Reader {
b := newBuffer(size, block)
b.id = "\treader"
return &Reader{
newBuffer(size, block),
b,
}
}
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
// there is sufficient capacity to do so.
func (b *Reader) ReadFrom(r io.Reader) (total int64, err error) {
debug.Println(b.id, "*[R] STARTING SPIN")
DONE:
for {
debug.Println(b.id, "*[R] SPIN")
if atomic.LoadInt64(&b.done) == 1 {
err = io.EOF
debug.Println(b.id, "*[R] readFrom DONE")
break DONE
}
fmt.Println("readingfrom")
// Wait until there's enough capacity in the buffer.
st, err := b.awaitCapacity(b.block)
if err != nil {
break DONE
debug.Println(b.id, "*[R] readFrom await capacity error, ", err)
continue
}
// Find the start and end indexes within the buffer.
@@ -50,21 +58,23 @@ DONE:
}
// Read into the buffer between the start and end indexes only.
debug.Println(b.id, "*[R] readFrom mapping (start, end)", start, end)
n, err := r.Read(b.buf[start:end])
total += int64(n) // incr total bytes read.
if err != nil {
break DONE
}
fmt.Println("<", b.buf[start:end])
// Move the head forward.
atomic.StoreInt64(&b.head, end)
debug.Println(b.id, "*[R] STORE HEAD", start+int64(n))
atomic.StoreInt64(&b.head, start+int64(n))
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
}
debug.Println(b.id, "*[R] FINISHED SPIN")
return
}
@@ -74,7 +84,7 @@ func (b *Reader) Peek(n int64) ([]byte, error) {
head := atomic.LoadInt64(&b.head)
// Wait until there's at least 1 byte of data.
/*
debug.Println(b.id, "[R] PEEKING (tail, head, n)", tail, head, n)
b.wcond.L.Lock()
for ; head == tail; head = atomic.LoadInt64(&b.head) {
if atomic.LoadInt64(&b.done) == 1 {
@@ -83,55 +93,38 @@ func (b *Reader) Peek(n int64) ([]byte, error) {
b.wcond.Wait()
}
b.wcond.L.Unlock()
*/
debug.Println(b.id, "[R] PEEKING available")
// Figure out if we can get all n bytes.
if head == tail || head > tail && tail+n > head || head < tail && b.size-tail+head < n {
//if head == tail || head > tail && tail+n > head || head < tail && b.size-tail+head < n {
if head == tail || head < tail && b.size-tail+head < n {
debug.Println(b.id, "[R] insufficient bytes (tail, next, head)", tail, tail+n, head)
return nil, ErrInsufficientBytes
}
var d int64
if head < tail {
d = b.size - tail + head
} else {
d = head - tail
}
// Peek maximum of n bytes.
if d > n {
d = n
}
// If we've wrapped, join the bytes from the tail to the end of the buffer with those from the start.
if head < tail {
b.tmp = b.buf[tail:b.size]
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
b.tmp = append(b.tmp, b.buf[:(tail+d)%b.size]...)
debug.Println(b.id, "[R] PEEKED wrapped (tail, next, buf)", tail, (tail+d)%b.size, b.tmp)
return b.tmp, nil
} else {
return b.buf[tail : tail+n], nil
}
return nil, nil
}
debug.Println(b.id, "[R] PEEKED (tail, next, buf)", tail, tail+d, b.buf[tail:tail+d])
return b.buf[tail : tail+d], nil
// PeekWait waits until there are n bytes available to return.
func (b *Reader) PeekWait(n int64) ([]byte, error) {
tail := atomic.LoadInt64(&b.tail)
head := atomic.LoadInt64(&b.head)
// Wait until there's at least 1 byte of data.
b.wcond.L.Lock()
for ; head == tail; head = atomic.LoadInt64(&b.head) {
if atomic.LoadInt64(&b.done) == 1 {
return nil, io.EOF
}
b.wcond.Wait()
}
b.wcond.L.Unlock()
// Figure out if we can get all n bytes.
if head > tail && tail+n > head || head < tail && b.size-tail+head < n {
return nil, ErrInsufficientBytes
}
// If we've wrapped, join the bytes from the tail to the end of the buffer with those from the start.
if head < tail {
b.tmp = b.buf[tail:b.size]
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
return b.tmp, nil
} else {
return b.buf[tail : tail+n], nil
}
return nil, nil
}
// Read reads the next n bytes from the buffer. If n bytes are not
@@ -143,19 +136,31 @@ func (b *Reader) Read(n int64) (p []byte, err error) {
}
// Wait until there's at least len(p) bytes to read.
debug.Println(b.id, "[R] awaiting read fill, (n)", n)
tail, err := b.awaitFilled(n)
if err != nil {
debug.Println(b.id, "[R] filled err (tail, err)", tail, err)
return
}
debug.Println(b.id, "[R] filled (tail, next)", tail, tail+n)
// Once enough bytes have been filled, determine whether the read wraps around
// the end of the buffer, and return the bytes read.
if atomic.LoadInt64(&b.head) < tail {
b.tmp = b.buf[tail:b.size]
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
atomic.StoreInt64(&b.tail, (tail+n)%b.size)
debug.Println(b.id, "[R] READ wrapped", b.tmp)
debug.Println(b.id, "[R] STORE TAIL wrapped", (tail+n)%b.size)
return b.tmp, nil
} else if tail+n < b.size {
atomic.StoreInt64(&b.tail, tail+n)
debug.Println(b.id, "[R] READ", b.buf[tail:tail+n])
debug.Println(b.id, "[R] STORE TAIL", tail+n)
return b.buf[tail : tail+n], nil
} else {
return p, fmt.Errorf("next byte extended size")
}
return
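Taken together, ReadFrom and Read form a producer/consumer pair around the ring: ReadFrom spins in its own goroutine committing bytes from an io.Reader, while Read blocks until n committed bytes can be returned. A rough usage sketch, assuming the circ API shown in this commit (NewReader(size, block), ReadFrom, Read, Close) and the module path github.com/mochi-co/mqtt:

package main

import (
	"bytes"
	"fmt"

	"github.com/mochi-co/mqtt/circ"
)

func main() {
	// Buffer and block sizes are illustrative.
	buf := circ.NewReader(1024, 128)

	// Fill the ring from an io.Reader in the background. ReadFrom spins
	// until the source is exhausted or the buffer is closed.
	go buf.ReadFrom(bytes.NewReader([]byte("hello mochi")))

	// Block until the first 5 committed bytes can be returned.
	p, err := buf.Read(5)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(p)) // "hello"

	buf.Close()
}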

View File

@@ -2,7 +2,6 @@ package circ
import (
"bytes"
"fmt"
"sync/atomic"
"testing"
"time"
@@ -131,8 +130,6 @@ func TestRead(t *testing.T) {
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
done := <-o
fmt.Println(done)
fmt.Println(string(done[0].([]byte)))
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
require.Equal(t, tt.bytes, done[0].([]byte), "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
}

View File

@@ -1,9 +1,10 @@
package circ
import (
"fmt"
"io"
"sync/atomic"
"github.com/mochi-co/mqtt/debug"
)
// Writer is a circular buffer for writing data to an io.Writer.
@@ -13,8 +14,10 @@ type Writer struct {
// NewWriter returns a pointer to a new Circular Writer.
func NewWriter(size, block int64) *Writer {
b := newBuffer(size, block)
b.id = "writer"
return &Writer{
newBuffer(size, block),
b,
}
}
@@ -22,36 +25,45 @@ func NewWriter(size, block int64) *Writer {
func (b *Writer) WriteTo(w io.Writer) (total int64, err error) {
var p []byte
var n int
DONE:
for {
cd := b.capDelta(atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head))
debug.Println(b.id, "SPIN (tail, head, delta)", b.tail, b.head, cd)
if atomic.LoadInt64(&b.done) == 1 {
if cd == 0 {
err = io.EOF
break DONE
return total, err
} else {
debug.Println("[W] //// capDelta not reached", cd)
}
}
fmt.Println("writingto")
// Peek until there are bytes to write, using the Peek method
// of the Reader type.
p, err = (*Reader)(b).Peek(b.block)
if err != nil {
break DONE
debug.Println(b.id, "[W] writeTo peek err (p, err)", p, err)
continue
//break DONE
}
fmt.Println(">", p)
debug.Println(b.id, "[W] PEEKED OK (p)", p)
// Write the peeked bytes to the io.Writer.
n, err = w.Write(p)
total += int64(n)
if err != nil {
break DONE
return total, err
}
debug.Println(b.id, "[W] SENT (n, p)", n, p)
// Move the tail forward the bytes written.
end := (atomic.LoadInt64(&b.tail) + int64(n)) % b.size
debug.Println(b.id, "[W] STORE TAIL", end)
atomic.StoreInt64(&b.tail, end)
b.rcond.L.Lock()
b.rcond.Broadcast()
b.rcond.L.Unlock()
debug.Println(b.id, "[W] writeTo unlocked")
}
return
@@ -71,6 +83,16 @@ func (b *Writer) Write(p []byte) (nn int, err error) {
// Write the outgoing bytes to the buffer, wrapping if necessary.
nn = b.writeBytes(p)
debug.Println(b.id, "[W] bytes written (nn)", nn, p)
// Move the head forward.
next := (atomic.LoadInt64(&b.head) + int64(nn)) % b.size
atomic.StoreInt64(&b.head, next)
debug.Println(b.id, "[W] STORE HEAD", next)
b.wcond.L.Lock()
b.wcond.Broadcast()
b.wcond.L.Unlock()
return
}
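The Writer mirrors the Reader: Write commits outgoing bytes to the ring, and WriteTo drains them into an io.Writer from its own goroutine. A rough sketch under the same assumptions about the circ API and module path:

package main

import (
	"bytes"
	"fmt"

	"github.com/mochi-co/mqtt/circ"
)

func main() {
	// Buffer and block sizes are illustrative.
	buf := circ.NewWriter(1024, 128)

	// Drain the ring into an io.Writer in the background, signalling when
	// the WriteTo loop has exited.
	out := new(bytes.Buffer)
	done := make(chan struct{})
	go func() {
		buf.WriteTo(out)
		close(done)
	}()

	// Write commits bytes to the ring; the WriteTo loop peeks and sends them.
	if _, err := buf.Write([]byte("hello mochi")); err != nil {
		panic(err)
	}

	buf.Close() // signal the event loop to stop once the ring is drained
	<-done      // wait for WriteTo to return before reading out
	fmt.Println(out.String())
}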

View File

@@ -1,7 +1,6 @@
package circ
import (
"fmt"
"io"
"io/ioutil"
"net"
@@ -104,7 +103,6 @@ func TestWrite(t *testing.T) {
}()
done := <-o
fmt.Println(done)
require.Equal(t, tt.want, done[0].(int), "Wanted written mismatch [i:%d] %s", i, tt.desc)
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)

View File

@@ -1,6 +1,7 @@
package mqtt
import (
"fmt"
"sync"
"sync/atomic"
@@ -77,8 +78,8 @@ func (cl *clients) getByListener(id string) []*client {
type client struct {
sync.RWMutex
// p is a packets parser which reads incoming packets.
p *Parser
// p is a packets processor which reads and writes packets.
p *Processor
// ac is a pointer to an auth controller inherited from the listener.
ac auth.Controller
@@ -123,7 +124,7 @@ type client struct {
}
// newClient creates a new instance of client.
func newClient(p *Parser, pk *packets.ConnectPacket, ac auth.Controller) *client {
func newClient(p *Processor, pk *packets.ConnectPacket, ac auth.Controller) *client {
cl := &client{
p: p,
ac: ac,
@@ -192,11 +193,19 @@ func (c *client) forgetSubscription(filter string) {
// close attempts to gracefully close a client connection.
func (cl *client) close() {
cl.done.Do(func() {
close(cl.end) // Signal to stop listening for packets.
if cl.end != nil {
close(cl.end) // Signal to stop listening for packets in mqtt readClient loop.
fmt.Println("closed readClient end")
}
// Close the processor.
cl.p.Stop()
fmt.Println("signalled stop, waiting")
fmt.Println("all processors ended")
// Close the network connection.
cl.p.Conn.Close() // Error is irrelevant so can be omitted here.
cl.p.Conn = nil
//cl.p.Conn.Close() // Error is irrelevant so can be omitted here.
//cl.p.Conn = nil
})
}

View File

@@ -2,13 +2,13 @@ package mqtt
import (
"log"
"net"
//"net"
"testing"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/auth"
"github.com/mochi-co/mqtt/circ"
"github.com/mochi-co/mqtt/packets"
"github.com/stretchr/testify/require"
)
func TestNewClients(t *testing.T) {
@@ -116,9 +116,7 @@ func BenchmarkClientsGetByListener(b *testing.B) {
}
func TestNewClient(t *testing.T) {
r, _ := net.Pipe()
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
r.Close()
p := NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
pk := &packets.ConnectPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Connect,
@@ -153,9 +151,7 @@ func TestNewClient(t *testing.T) {
}
func TestNewClientLWT(t *testing.T) {
r, _ := net.Pipe()
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
r.Close()
p := NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
pk := &packets.ConnectPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Connect,
@@ -181,11 +177,8 @@ func TestNewClientLWT(t *testing.T) {
}
func BenchmarkNewClient(b *testing.B) {
r, _ := net.Pipe()
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
r.Close()
p := NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
pk := new(packets.ConnectPacket)
for n := 0; n < b.N; n++ {
newClient(p, pk, new(auth.Allow))
}
@@ -244,15 +237,13 @@ func BenchmarkClientForgetSubscription(b *testing.B) {
}
func TestClientClose(t *testing.T) {
r, w := net.Pipe()
p := NewParser(r, newBufioReader(r), newBufioWriter(w))
p := NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
pk := &packets.ConnectPacket{
ClientIdentifier: "zen3",
}
client := newClient(p, pk, new(auth.Allow))
require.NotNil(t, client)
client.close()
var ok bool
@@ -261,8 +252,6 @@ func TestClientClose(t *testing.T) {
}
require.Equal(t, false, ok)
require.Nil(t, client.p.Conn)
r.Close()
w.Close()
}
func TestInFlightSet(t *testing.T) {

debug/log.go (new file, 9 lines)
View File

@@ -0,0 +1,9 @@
// +build !debug
package debug
// Println will be optimized away.
func Println(a ...interface{}) {}
// Printf will be optimized away.
func Printf(format string, a ...interface{}) {}

debug/log_debug.go (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
// +build debug
package debug
import (
"fmt"
)
// Println prints its arguments when the debug build tag is set.
func Println(a ...interface{}) {
fmt.Println(a...)
}
// Printf prints a formatted string when the debug build tag is set.
func Printf(format string, a ...interface{}) {
fmt.Printf(format, a...)
}
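These two files select an implementation by build tag: the empty stubs in debug/log.go compile by default, and the fmt-backed versions in debug/log_debug.go are only included when building with the debug tag (for example go build -tags debug or go test -tags debug). A minimal sketch of a caller, assuming the module path github.com/mochi-co/mqtt:

package main

import "github.com/mochi-co/mqtt/debug"

func main() {
	// In a default build this call compiles to a no-op stub.
	// Rebuild with the tag to see the output:
	//   go run -tags debug .
	debug.Println("buffer state (tail, head)", 0, 128)
}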

log.go (new file, 1 line)
View File

@@ -0,0 +1 @@
package mqtt

mqtt.go (129 lines changed)
View File

@@ -1,18 +1,18 @@
package mqtt
import (
"bufio"
"bytes"
"errors"
"fmt"
"net"
"time"
//"github.com/davecgh/go-spew/spew"
"github.com/mochi-co/mqtt/auth"
"github.com/mochi-co/mqtt/circ"
"github.com/mochi-co/mqtt/listeners"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/processors/circ"
"github.com/mochi-co/mqtt/topics"
"github.com/mochi-co/mqtt/topics/trie"
)
@@ -36,9 +36,6 @@ var (
ErrConnectionClosed = errors.New("Connection not open")
ErrNoData = errors.New("No data")
ErrACLNotAuthorized = errors.New("ACL not authorized")
// rwBufSize is the size of client read/write buffers.
rwBufSize = 512
)
// Server is an MQTT broker server.
@@ -91,16 +88,19 @@ func (s *Server) AddListener(listener listeners.Listener, config *listeners.Conf
// Serve begins the event loops for establishing client connections on all
// attached listeners.
func (s *Server) Serve() error {
s.listeners.ServeAll(s.EstablishConnection2)
s.listeners.ServeAll(s.EstablishConnection)
return nil
}
// EstablishConnection establishes a new client connection with the broker.
func (s *Server) EstablishConnection2(lid string, c net.Conn, ac auth.Controller) error {
var size int64 = 1024 * 256
p := NewProcessor(c, circ.NewReader(size), circ.NewWriter(size))
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
p := NewProcessor(
c,
circ.NewReader(0, 0),
circ.NewWriter(0, 0),
)
p.Start()
// Pull the header from the first packet and check for a CONNECT message.
fh := new(packets.FixedHeader)
@@ -109,9 +109,94 @@ func (s *Server) EstablishConnection2(lid string, c net.Conn, ac auth.Controller
return ErrReadConnectFixedHeader
}
return nil
// Read the first packet expecting a CONNECT message.
pk, err := p.Read()
if err != nil {
return ErrReadConnectPacket
}
// Ensure first packet is a connect packet.
msg, ok := pk.(*packets.ConnectPacket)
if !ok {
return ErrFirstPacketInvalid
}
// Ensure the packet conforms to MQTT CONNECT specifications.
retcode, _ := msg.Validate()
if retcode != packets.Accepted {
return ErrReadConnectInvalid
}
// If a username and password have been provided, perform authentication.
if msg.Username != "" && !ac.Authenticate(msg.Username, msg.Password) {
retcode = packets.CodeConnectNotAuthorised
return ErrConnectNotAuthorized
}
// Create a new client with this connection.
client := newClient(p, msg, ac)
client.listener = lid // Note the listener the client is connected to.
// If it's not a clean session and a client already exists with the same
// id, inherit the session.
var sessionPresent bool
if existing, ok := s.clients.get(msg.ClientIdentifier); ok {
existing.Lock()
existing.close()
if msg.CleanSession {
for k := range existing.subscriptions {
s.topics.Unsubscribe(k, existing.id)
}
} else {
client.inFlight = existing.inFlight // Inherit from existing session.
client.subscriptions = existing.subscriptions
sessionPresent = true
}
existing.Unlock()
}
// Add the new client to the clients manager.
s.clients.add(client)
// Send a CONNACK back to the client.
err = s.writeClient(client, &packets.ConnackPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Connack,
},
SessionPresent: sessionPresent,
ReturnCode: retcode,
})
if err != nil {
fmt.Println("WRITECLIENT err", err)
return err
}
fmt.Println("----------------")
// Resend any unacknowledged QOS messages still pending for the client.
/*err = s.resendInflight(client)
if err != nil {
return err
}*/
// Block and listen for more packets, and end if an error or nil packet occurs.
var sendLWT bool
err = s.readClient(client)
if err != nil {
sendLWT = true // Only send LWT on bad disconnect [MQTT-3.14.4-3]
}
fmt.Println("** stop and wait", err)
// Publish last will and testament.
s.closeClient(client, sendLWT)
fmt.Println("stopped")
return err // capture err or nil from readClient.
}
/*
// EstablishConnection establishes a new client connection with the broker.
func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller) error {
@@ -205,6 +290,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
return err // capture err or nil from readClient.
}
*/
// resendInflight republishes any inflight messages to the client.
func (s *Server) resendInflight(cl *client) error {
@@ -226,7 +312,7 @@ func (s *Server) readClient(cl *client) error {
var err error
var pk packets.Packet
fh := new(packets.FixedHeader)
fmt.Println("STARTING READ CLIENT LOOP")
DONE:
for {
select {
@@ -250,15 +336,20 @@ DONE:
// If it's a disconnect packet, begin the close process.
if fh.Type == packets.Disconnect {
return nil
fmt.Println(" X got disconnect")
break DONE
}
// Otherwise read in the packet payload.
pk, err = cl.p.Read()
if err != nil {
return ErrReadPacketPayload
fmt.Println("RC READ ERR", err)
//return ErrReadPacketPayload
return nil
}
fmt.Println("READ PACKET", pk)
// Validate the packet if necessary.
_, err = pk.Validate()
if err != nil {
@@ -270,6 +361,7 @@ DONE:
}
}
fmt.Println("READCLIENT COMPLETED")
return nil
}
@@ -555,8 +647,8 @@ func (s *Server) processUnsubscribe(cl *client, pk *packets.UnsubscribePacket) e
// writeClient writes packets to a client connection.
func (s *Server) writeClient(cl *client, pk packets.Packet) error {
cl.p.Lock()
defer cl.p.Unlock()
cl.p.W.Lock()
defer cl.p.W.Unlock()
// Ensure Writer is open.
if cl.p.W == nil {
@@ -575,11 +667,6 @@ func (s *Server) writeClient(cl *client, pk packets.Packet) error {
return err
}
err = cl.p.W.Flush()
if err != nil {
return err
}
// Refresh deadline to keep the connection alive.
cl.p.RefreshDeadline(cl.keepalive)

View File

@@ -2,6 +2,7 @@ package mqtt
import (
"errors"
"fmt"
"io/ioutil"
"net"
"testing"
@@ -10,11 +11,18 @@ import (
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/auth"
"github.com/mochi-co/mqtt/circ"
"github.com/mochi-co/mqtt/listeners"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/topics"
//"github.com/mochi-co/mqtt/topics"
)
// bufferSize is the buffer size of the processors.
var bufferSize int64 = 512
// blockSize is the block size of the processors.
var blockSize int64 = 128
type quietWriter struct {
b []byte
f [][]byte
@@ -45,9 +53,8 @@ func (q *quietWriter) Flush() error {
func setupClient(id string) (s *Server, r net.Conn, w net.Conn, cl *client) {
s = New()
r, w = net.Pipe()
cl = newClient(
NewParser(r, newBufioReader(r), newBufioWriter(w)),
NewProcessor(new(MockNetConn), circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize)),
&packets.ConnectPacket{
ClientIdentifier: id,
},
@@ -178,10 +185,64 @@ func BenchmarkServerClose(b *testing.B) {
*/
func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
r2, w2 := net.Pipe()
s := New()
r, w := net.Pipe()
o := make(chan error)
go func() {
o <- s.EstablishConnection("tcp", r, new(auth.Allow))
}()
go func() {
w.Write([]byte{
byte(packets.Connect << 4), 15, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
2, // Packet Flags - clean session
0, 45, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
})
w.Write([]byte{byte(packets.Disconnect << 4), 0})
}()
recv := make(chan []byte)
go func() {
buf, err := ioutil.ReadAll(w)
if err != nil {
panic(err)
}
recv <- buf
}()
//time.Sleep(10 * time.Millisecond)
s.clients.Lock()
for _, v := range s.clients.internal {
v.close()
break
}
s.clients.Unlock()
errx := <-o
require.NoError(t, errx)
require.Equal(t, []byte{
byte(packets.Connack << 4), 2,
0, packets.Accepted,
}, <-recv)
fmt.Println()
w.Close()
}
/*
func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
r2, _ := net.Pipe()
s := New()
px := NewProcessor(r2, circ.NewReader(bufferSize, blockSize), circ.NewWriter(bufferSize, blockSize))
px.Start()
s.clients.internal["zen"] = newClient(
NewParser(r2, newBufioReader(r2), newBufioWriter(w2)),
px,
&packets.ConnectPacket{ClientIdentifier: "zen"},
new(auth.Allow),
)
@@ -228,6 +289,8 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
w.Close()
}
/*
func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
r2, w2 := net.Pipe()
s := New()
@@ -496,8 +559,8 @@ func TestResendInflightWriteError(t *testing.T) {
* Server Read Client
*/
*/
/*
func TestServerReadClientOK(t *testing.T) {
s, r, w, cl := setupClient("zen")
@@ -603,8 +666,8 @@ func TestServerReadClientBadPacketValidation(t *testing.T) {
* Server Write Client
*/
*/
/*
func TestServerWriteClient(t *testing.T) {
s, _, _, cl := setupClient("zen")
cl.p.W = new(quietWriter)
@@ -676,8 +739,8 @@ func TestServerWriteClientWriteError(t *testing.T) {
* Server Close Client
*/
*/
/*
func TestServerCloseClient(t *testing.T) { // as opposed to client.close
s, _, _, cl := setupClient("zen")
s.clients.add(cl)
@@ -748,8 +811,8 @@ func TestServerCloseClientLWTWriteError(t *testing.T) { // as opposed to client.
* Server Process Packets
*/
*/
/*
func TestServerProcessPacket(t *testing.T) {
s, _, _, cl := setupClient("zen")
err := s.processPacket(cl, nil)
@@ -1368,3 +1431,4 @@ func TestServerProcessUnsubscribeWriteError(t *testing.T) {
})
require.Error(t, err)
}
*/

parser.go (removed, 192 lines)
View File

@@ -1,192 +0,0 @@
package mqtt
import (
"bufio"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
"github.com/mochi-co/mqtt/packets"
)
// BufWriter is an interface for satisfying a bufio.Writer. This is mainly
// in place to allow testing.
type BufWriter interface {
// Write writes a byte buffer.
Write(p []byte) (nn int, err error)
// Flush flushes the buffer.
Flush() error
}
// Parser is an MQTT packet parser that reads and writes MQTT payloads to a
// buffered IO stream.
type Parser struct {
sync.RWMutex
// Conn is the net.Conn used to establish the connection.
Conn net.Conn
// R is a bufio reader for peeking and reading incoming packets.
R *bufio.Reader
// W is a bufio writer for writing outgoing packets.
W BufWriter
// FixedHeader is the fixed header from the last seen packet.
FixedHeader packets.FixedHeader
}
// NewParser returns an instance of Parser for a connection.
func NewParser(c net.Conn, r *bufio.Reader, w BufWriter) *Parser {
return &Parser{
Conn: c,
R: r,
W: w,
}
}
// RefreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (p *Parser) RefreshDeadline(keepalive uint16) {
if p.Conn != nil {
expiry := time.Duration(keepalive+(keepalive/2)) * time.Second
p.Conn.SetDeadline(time.Now().Add(expiry))
}
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (p *Parser) ReadFixedHeader(fh *packets.FixedHeader) error {
// Peek the maximum message type and flags, and length.
peeked, err := p.R.Peek(1)
if err != nil {
return err
}
// Unpack message type and flags from byte 1.
err = fh.Decode(peeked[0])
if err != nil {
// @SPEC [MQTT-2.2.2-2]
// If invalid flags are received, the receiver MUST close the Network Connection.
return packets.ErrInvalidFlags
}
// The remaining length value can be up to 5 bytes. Peek through each byte
// looking for continue values, and if found increase the peek. Otherwise
// decode the bytes that were legit.
//p.fhBuffer = p.fhBuffer[:0]
buf := make([]byte, 0, 6)
i := 1
b := 2
for ; b < 6; b++ {
peeked, err = p.R.Peek(b)
if err != nil {
return err
}
// Add the byte to the length bytes slice.
buf = append(buf, peeked[i])
// If it's not a continuation flag, end here.
if peeked[i] < 128 {
break
}
// If i has reached 4 without a length terminator, throw a protocol violation.
i++
if i == 4 {
return packets.ErrOversizedLengthIndicator
}
}
// Calculate and store the remaining length.
rem, _ := binary.Uvarint(buf)
fh.Remaining = int(rem)
// Discard the number of used length bytes + first byte.
p.R.Discard(b)
// Set the fixed header in the parser.
p.FixedHeader = *fh
return nil
}
// Read reads the remaining buffer into an MQTT packet.
func (p *Parser) Read() (pk packets.Packet, err error) {
switch p.FixedHeader.Type {
case packets.Connect:
pk = &packets.ConnectPacket{FixedHeader: p.FixedHeader}
case packets.Connack:
pk = &packets.ConnackPacket{FixedHeader: p.FixedHeader}
case packets.Publish:
pk = &packets.PublishPacket{FixedHeader: p.FixedHeader}
case packets.Puback:
pk = &packets.PubackPacket{FixedHeader: p.FixedHeader}
case packets.Pubrec:
pk = &packets.PubrecPacket{FixedHeader: p.FixedHeader}
case packets.Pubrel:
pk = &packets.PubrelPacket{FixedHeader: p.FixedHeader}
case packets.Pubcomp:
pk = &packets.PubcompPacket{FixedHeader: p.FixedHeader}
case packets.Subscribe:
pk = &packets.SubscribePacket{FixedHeader: p.FixedHeader}
case packets.Suback:
pk = &packets.SubackPacket{FixedHeader: p.FixedHeader}
case packets.Unsubscribe:
pk = &packets.UnsubscribePacket{FixedHeader: p.FixedHeader}
case packets.Unsuback:
pk = &packets.UnsubackPacket{FixedHeader: p.FixedHeader}
case packets.Pingreq:
pk = &packets.PingreqPacket{FixedHeader: p.FixedHeader}
case packets.Pingresp:
pk = &packets.PingrespPacket{FixedHeader: p.FixedHeader}
case packets.Disconnect:
pk = &packets.DisconnectPacket{FixedHeader: p.FixedHeader}
default:
return pk, errors.New("No valid packet available; " + string(p.FixedHeader.Type))
}
// Attempt to peek the rest of the packet.
peeked := true
bt, err := p.R.Peek(p.FixedHeader.Remaining)
if err != nil {
// Only try to continue if reading is still possible.
if err != bufio.ErrBufferFull {
return pk, err
}
// If it didn't work, read the buffer directly, and if that still doesn't
// work, then throw an error.
peeked = false
bt = make([]byte, p.FixedHeader.Remaining)
_, err := io.ReadFull(p.R, bt)
if err != nil {
return pk, err
}
}
// If peeking was successful, discard the rest of the packet now it's been read.
if peeked {
p.R.Discard(p.FixedHeader.Remaining)
}
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
// 🚨 DANGER!! This line is super important. If the bytes being decoded are not
// in their own memory space, packets will get corrupted all over the place.
err = pk.Decode(append([]byte{}, bt[:]...)) // <--- MUST BE A COPY.
if err != nil {
return pk, err
}
return
}

View File

@@ -1,528 +0,0 @@
package mqtt
import (
"bufio"
"bytes"
"net"
"testing"
//"time"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/packets"
)
func newBufioReader(c net.Conn) *bufio.Reader {
return bufio.NewReaderSize(c, 512)
}
func newBufioWriter(c net.Conn) *bufio.Writer {
return bufio.NewWriterSize(c, 512)
}
func TestNewParser(t *testing.T) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
require.NotNil(t, p.R)
}
func BenchmarkNewParser(b *testing.B) {
conn := new(MockNetConn)
r, w := new(bufio.Reader), new(bufio.Writer)
for n := 0; n < b.N; n++ {
NewParser(conn, r, w)
}
}
func TestRefreshDeadline(t *testing.T) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
dl := p.Conn.(*MockNetConn).Deadline
p.RefreshDeadline(10)
require.NotEqual(t, dl, p.Conn.(*MockNetConn).Deadline)
}
func BenchmarkRefreshDeadline(b *testing.B) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
for n := 0; n < b.N; n++ {
p.RefreshDeadline(10)
}
}
/*
type fixedHeaderTable struct {
rawBytes []byte
header packets.FixedHeader
packetError bool
flagError bool
}
var fixedHeaderExpected = []fixedHeaderTable{
{
rawBytes: []byte{packets.Connect << 4, 0x00},
header: packets.FixedHeader{packets.Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int
},
{
rawBytes: []byte{packets.Connack << 4, 0x00},
header: packets.FixedHeader{packets.Connack, false, 0, false, 0},
},
{
rawBytes: []byte{packets.Publish << 4, 0x00},
header: packets.FixedHeader{packets.Publish, false, 0, false, 0},
},
{
rawBytes: []byte{packets.Publish<<4 | 1<<1, 0x00},
header: packets.FixedHeader{packets.Publish, false, 1, false, 0},
},
{
rawBytes: []byte{packets.Publish<<4 | 1<<1 | 1, 0x00},
header: packets.FixedHeader{packets.Publish, false, 1, true, 0},
},
{
rawBytes: []byte{packets.Publish<<4 | 2<<1, 0x00},
header: packets.FixedHeader{packets.Publish, false, 2, false, 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 2, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Publish, true, 0, false, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Publish, true, 0, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 1, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 2, true, 0},
},
{
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Puback, false, 0, false, 0},
},
{
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Pubrec, false, 0, false, 0},
},
{
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Pubrel, false, 1, false, 0},
},
{
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Pubcomp, false, 0, false, 0},
},
{
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Subscribe, false, 1, false, 0},
},
{
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Suback, false, 0, false, 0},
},
{
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Unsubscribe, false, 1, false, 0},
},
{
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Unsuback, false, 0, false, 0},
},
{
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Pingreq, false, 0, false, 0},
},
{
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Pingresp, false, 0, false, 0},
},
{
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Disconnect, false, 0, false, 0},
},
// remaining length
{
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Publish, false, 0, false, 10},
},
{
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Publish, false, 0, false, 512},
},
{
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Publish, false, 0, false, 978},
},
{
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Publish, false, 0, false, 20102},
},
{
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Publish, false, 0, false, 333333333},
packetError: true,
},
// Invalid flags for packet
{
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Connect, true, 0, false, 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Connect, false, 1, false, 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Connect, false, 0, true, 0},
flagError: true,
},
}
*/
func TestParserReadFixedHeader(t *testing.T) {
conn := new(MockNetConn)
// Test null data.
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
fh := new(packets.FixedHeader)
err := p.ReadFixedHeader(fh)
require.Error(t, err)
// Test insufficient peeking.
fh = new(packets.FixedHeader)
p.R = bufio.NewReader(bytes.NewReader([]byte{packets.Connect << 4}))
err = p.ReadFixedHeader(fh)
require.Error(t, err)
fh = new(packets.FixedHeader)
p.R = bufio.NewReader(bytes.NewReader([]byte{packets.Connect << 4, 0x00}))
err = p.ReadFixedHeader(fh)
require.NoError(t, err)
/*
// Test expected bytes.
for i, wanted := range fixedHeaderExpected {
fh := new(packets.FixedHeader)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
b := wanted.rawBytes
p.R = bufio.NewReader(bytes.NewReader(b))
err := p.ReadFixedHeader(fh)
if wanted.packetError || wanted.flagError {
require.Error(t, err, "Expected error [i:%d] %v", i, wanted.rawBytes)
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Type, p.FixedHeader.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Dup, p.FixedHeader.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Qos, p.FixedHeader.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Retain, p.FixedHeader.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
}
}
*/
}
func BenchmarkReadFixedHeader(b *testing.B) {
conn := new(MockNetConn)
fh := new(packets.FixedHeader)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
var rn bytes.Reader = *bytes.NewReader([]byte{packets.Connect << 4, 0x00})
var rc bytes.Reader
for n := 0; n < b.N; n++ {
rc = rn
p.R.Reset(&rc)
err := p.ReadFixedHeader(fh)
if err != nil {
panic(err)
}
}
}
func TestRead(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader([]byte{
byte(packets.Publish << 4), 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
}))
err := p.ReadFixedHeader(&fh)
require.NoError(t, err)
pko, err := p.Read()
require.NoError(t, err)
require.Equal(t, &packets.PublishPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 18,
},
TopicName: "a/b/c",
Payload: []byte("hello mochi"),
}, pko)
/*
for code, pt := range expectedPackets {
for i, wanted := range pt {
if wanted.primary {
var fh packets.FixedHeader
b := wanted.rawBytes
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader(b))
err := p.ReadFixedHeader(&fh)
if wanted.failFirst != nil {
require.Error(t, err, "Expected error reading fixedheader [i:%d] %s - %s", i, wanted.desc, Names[code])
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %s - %s", i, wanted.desc, Names[code])
}
pko, err := p.Read()
if wanted.expect != nil {
require.Error(t, err, "Expected error reading packet [i:%d] %s - %s", i, wanted.desc, Names[code])
if err != nil {
require.Equal(t, err, wanted.expect, "Mismatched packet error [i:%d] %s - %s", i, wanted.desc, Names[code])
}
} else {
require.NoError(t, err, "Error reading packet [i:%d] %s - %s", i, wanted.desc, Names[code])
require.Equal(t, wanted.packet, pko, "Mismatched packet final [i:%d] %s - %s", i, wanted.desc, Names[code])
}
}
}
}
*/
}
func TestReadFail(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader([]byte{
byte(packets.Publish << 4), 3, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/',
}))
err := p.ReadFixedHeader(&fh)
require.NoError(t, err)
_, err = p.Read()
require.Error(t, err)
}
func BenchmarkRead(b *testing.B) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
pkb := []byte{
byte(packets.Publish << 4), 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
}
p.R = bufio.NewReader(bytes.NewReader(pkb))
var fh packets.FixedHeader
err := p.ReadFixedHeader(&fh)
if err != nil {
panic(err)
}
var rn bytes.Reader = *bytes.NewReader(pkb)
var rc bytes.Reader
for n := 0; n < b.N; n++ {
rc = rn
p.R.Reset(&rc)
p.R.Discard(2)
_, err := p.Read()
if err != nil {
panic(err)
}
}
}
// This is a super important test. It checks whether or not subsequent packets
// mutate each other. This happens when you use a single byte buffer for decoding
// multiple packets.
func TestReadPacketNoOverwrite(t *testing.T) {
pk1 := []byte{
byte(packets.Publish << 4), 12, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', // Payload
}
pk2 := []byte{
byte(packets.Publish << 4), 14, // Fixed header
0, 5, // Topic Name - LSB+MSB
'x', '/', 'y', '/', 'z', // Topic Name
'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload
}
r, w := net.Pipe()
p := NewParser(r, newBufioReader(r), newBufioWriter(w))
go func() {
w.Write(pk1)
w.Write(pk2)
w.Close()
}()
var fh packets.FixedHeader
err := p.ReadFixedHeader(&fh)
require.NoError(t, err)
o1, err := p.Read()
require.NoError(t, err)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, pk1[9:])
require.NoError(t, err)
err = p.ReadFixedHeader(&fh)
require.NoError(t, err)
o2, err := p.Read()
require.NoError(t, err)
require.Equal(t, []byte{'y', 'a', 'h', 'a', 'l', 'l', 'o'}, o2.(*packets.PublishPacket).Payload)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload, "o1 payload was mutated")
}
func TestReadPacketNil(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{0, 0}))
err := p.ReadFixedHeader(&fh)
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
func TestReadPacketReadOverflow(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{byte(packets.Connect << 4), 0}))
err := p.ReadFixedHeader(&fh)
p.FixedHeader.Remaining = 999999 // overflow buffer
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
func TestReadPacketReadAllFail(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{byte(packets.Connect << 4), 0}))
err := p.ReadFixedHeader(&fh)
p.FixedHeader.Remaining = 1 // overflow buffer
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
/*
// MockNetConn satisfies the net.Conn interface.
type MockNetConn struct {
ID string
Deadline time.Time
}
// Read reads bytes from the net io.reader.
func (m *MockNetConn) Read(b []byte) (n int, err error) {
return 0, nil
}
// Write writes bytes to the net io.Writer.
func (m *MockNetConn) Write(b []byte) (n int, err error) {
return 0, nil
}
// Close closes the net.Conn connection.
func (m *MockNetConn) Close() error {
return nil
}
// LocalAddr returns the local address of the request.
func (m *MockNetConn) LocalAddr() net.Addr {
return new(MockNetAddr)
}
// RemoteAddr returns the remote address of the request.
func (m *MockNetConn) RemoteAddr() net.Addr {
return new(MockNetAddr)
}
// SetDeadline sets the request deadline.
func (m *MockNetConn) SetDeadline(t time.Time) error {
m.Deadline = t
return nil
}
// SetReadDeadline sets the read deadline.
func (m *MockNetConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline sets the write deadline.
func (m *MockNetConn) SetWriteDeadline(t time.Time) error {
return nil
}
// MockNetAddr satisfies net.Addr interface.
type MockNetAddr struct{}
// Network returns the network protocol.
func (m *MockNetAddr) Network() string {
return "tcp"
}
// String returns the network address.
func (m *MockNetAddr) String() string {
return "127.0.0.1"
}
*/

View File

@@ -1,67 +0,0 @@
package pools
/*
import (
"bufio"
"io"
"sync"
)
// BufioReadersPool is a pool of bufio.Reader.
type BufioReadersPool struct {
pool sync.Pool
}
// NewBufioReadersPool returns a sync.pool of bufio.Reader.
func NewBufioReadersPool(size int) BufioReadersPool {
return BufioReadersPool{
pool: sync.Pool{
New: func() interface{} {
return bufio.NewReaderSize(nil, size)
},
},
}
}
// Get returns a pooled bufio.Reader.
func (b BufioReadersPool) Get(r io.Reader) *bufio.Reader {
buf := b.pool.Get().(*bufio.Reader)
buf.Reset(r)
return buf
}
// Put puts the bufio.Reader back into the pool.
func (b BufioReadersPool) Put(x *bufio.Reader) {
x.Reset(nil)
b.pool.Put(x)
}
// BufioWritersPool is a pool of bufio.Writer.
type BufioWritersPool struct {
pool sync.Pool
}
// NewBufioWritersPool returns a sync.pool of bufio.Writer.
func NewBufioWritersPool(size int) BufioWritersPool {
return BufioWritersPool{
pool: sync.Pool{
New: func() interface{} {
return bufio.NewWriterSize(nil, size)
},
},
}
}
// Get returns a pooled bufio.Writer.
func (b BufioWritersPool) Get(r io.Writer) *bufio.Writer {
buf := b.pool.Get().(*bufio.Writer)
buf.Reset(r)
return buf
}
// Put puts the bufio.Writer back into the pool.
func (b BufioWritersPool) Put(x *bufio.Writer) {
x.Reset(nil)
b.pool.Put(x)
}
*/

View File

@@ -1,115 +0,0 @@
package pools
/*
import (
"bufio"
"bytes"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestNewBufioReadersPool(t *testing.T) {
bpool := NewBufioReadersPool(16)
require.NotNil(t, bpool.pool)
}
func BenchmarkNewBufioReadersPool(b *testing.B) {
for n := 0; n < b.N; n++ {
NewBufioReadersPool(16)
}
}
func TestNewBufioReadersGetPut(t *testing.T) {
bpool := NewBufioReadersPool(16)
r := bufio.NewReader(strings.NewReader("mochi"))
out := make([]byte, 5)
buf := bpool.Get(r)
n, err := buf.Read(out)
require.Equal(t, 5, n)
require.NoError(t, err)
require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, out)
bpool.Put(buf)
require.Panics(t, func() {
buf.ReadByte()
})
}
func BenchmarkBufioReadersGet(b *testing.B) {
bpool := NewBufioReadersPool(16)
r := bufio.NewReader(strings.NewReader("mochi"))
for n := 0; n < b.N; n++ {
bpool.Get(r)
}
}
func BenchmarkBufioReadersPut(b *testing.B) {
bpool := NewBufioReadersPool(16)
r := bufio.NewReader(strings.NewReader("mochi"))
buf := bpool.Get(r)
for n := 0; n < b.N; n++ {
bpool.Put(buf)
}
}
func TestNewBufioWritersPool(t *testing.T) {
bpool := NewBufioWritersPool(16)
require.NotNil(t, bpool.pool)
}
func BenchmarkNewBufioWritersPool(b *testing.B) {
for n := 0; n < b.N; n++ {
NewBufioWritersPool(16)
}
}
func TestNewBufioWritersGetPut(t *testing.T) {
payload := []byte{'m', 'o', 'c', 'h', 'i'}
bpool := NewBufioWritersPool(16)
buf := new(bytes.Buffer)
bw := bufio.NewWriter(buf)
require.NotNil(t, bw)
w := bpool.Get(bw)
require.NotNil(t, w)
n, err := w.Write(payload)
require.NoError(t, err)
require.Equal(t, 5, n)
w.Flush()
bw.Flush()
require.Equal(t, payload, buf.Bytes())
bpool.Put(w)
_, err = w.Write(payload)
require.NoError(t, err)
require.Panics(t, func() {
w.Flush()
})
}
func BenchmarkBufioWritersGet(b *testing.B) {
bpool := NewBufioWritersPool(16)
buf := new(bytes.Buffer)
bw := bufio.NewWriter(buf)
for n := 0; n < b.N; n++ {
bpool.Get(bw)
}
}
func BenchmarkBufioWritersPut(b *testing.B) {
bpool := NewBufioWritersPool(16)
w := bpool.Get(bufio.NewWriter(new(bytes.Buffer)))
for n := 0; n < b.N; n++ {
bpool.Put(w)
}
}
*/

View File

@@ -1,33 +0,0 @@
package pools
import (
"bytes"
"sync"
)
// BytesBuffersPool is a pool of bytes.Buffer.
type BytesBuffersPool struct {
pool sync.Pool
}
// NewBytesBuffersPool returns a sync.pool of bytes.Buffer.
func NewBytesBuffersPool() BytesBuffersPool {
return BytesBuffersPool{
pool: sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
},
}
}
// Get returns a pooled bytes.Buffer.
func (b BytesBuffersPool) Get() *bytes.Buffer {
return b.pool.Get().(*bytes.Buffer)
}
// Put puts the bytes.Buffer back into the pool.
func (b BytesBuffersPool) Put(x *bytes.Buffer) {
x.Reset()
b.pool.Put(x)
}

View File

@@ -1,52 +0,0 @@
package pools
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewBytesBuffersPool(t *testing.T) {
bpool := NewBytesBuffersPool()
require.NotNil(t, bpool.pool)
}
func BenchmarkNewBytesBuffersPool(b *testing.B) {
for n := 0; n < b.N; n++ {
NewBytesBuffersPool()
}
}
func TestNewBytesBuffersPoolGet(t *testing.T) {
bpool := NewBytesBuffersPool()
buf := bpool.Get()
buf.WriteString("mochi")
require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, buf.Bytes())
}
func BenchmarkBytesBuffersPoolGet(b *testing.B) {
bpool := NewBytesBuffersPool()
for n := 0; n < b.N; n++ {
bpool.Get()
}
}
func TestNewBytesBuffersPoolPut(t *testing.T) {
bpool := NewBytesBuffersPool()
buf := bpool.Get()
buf.WriteString("mochi")
require.Equal(t, []byte{'m', 'o', 'c', 'h', 'i'}, buf.Bytes())
bpool.Put(buf)
require.Equal(t, 0, buf.Len())
}
func BenchmarkBytesBuffersPoolPut(b *testing.B) {
bpool := NewBytesBuffersPool()
buf := bpool.Get()
for n := 0; n < b.N; n++ {
bpool.Put(buf)
}
}

View File

@@ -2,12 +2,13 @@ package mqtt
import (
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/mochi-co/mqtt/circ"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/processors/circ"
)
// Processor reads and writes bytes to a network connection.
@@ -22,6 +23,12 @@ type Processor struct {
// W is a writer for writing outgoing bytes.
W *circ.Writer
// started tracks the goroutines which have been started.
started sync.WaitGroup
// ended tracks the goroutines which have ended.
ended sync.WaitGroup
// FixedHeader is the FixedHeader from the last read packet.
FixedHeader packets.FixedHeader
}
@@ -33,6 +40,51 @@ func NewProcessor(c net.Conn, r *circ.Reader, w *circ.Writer) *Processor {
R: r,
W: w,
}
}
// Start spins up the reader and writer goroutines.
func (p *Processor) Start() {
go func() {
defer p.ended.Done()
fmt.Println("starting readFrom", p.Conn)
p.started.Done()
n, err := p.R.ReadFrom(p.Conn)
if err != nil {
//
}
fmt.Println(">>> finished ReadFrom", n, err)
}()
go func() {
defer p.ended.Done()
fmt.Println("starting writeTo", p.Conn)
p.started.Done()
n, err := p.W.WriteTo(p.Conn)
if err != nil {
//
}
fmt.Println(">>> finished WriteTo", n, err)
}()
p.started.Add(2)
p.ended.Add(2)
p.started.Wait()
}
// Stop stops the processor goroutines.
func (p *Processor) Stop() {
fmt.Println("processor stop")
p.W.Close()
if p.Conn != nil {
p.Conn.Close()
}
p.R.Close()
p.ended.Wait()
}
// RefreshDeadline refreshes the read/write deadline for the net.Conn connection.
@@ -99,6 +151,8 @@ func (p *Processor) ReadFixedHeader(fh *packets.FixedHeader) error {
return err
}
fmt.Println("FIXEDHEADER READ", *fh)
// Set the fixed header in the parser.
p.FixedHeader = *fh
@@ -139,7 +193,7 @@ func (p *Processor) Read() (pk packets.Packet, err error) {
case packets.Disconnect:
pk = &packets.DisconnectPacket{FixedHeader: p.FixedHeader}
default:
return pk, errors.New("No valid packet available; " + string(p.FixedHeader.Type))
return pk, fmt.Errorf("No valid packet available; %v", p.FixedHeader.Type)
}
bt, err := p.R.Read(int64(p.FixedHeader.Remaining))
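Start and Stop give the Processor a simple lifecycle: Start launches the ReadFrom and WriteTo goroutines and returns once both are spinning, while Stop closes the buffers and the connection and waits for both goroutines to exit. A rough sketch of that lifecycle, assuming the NewProcessor, Start and Stop signatures shown in this commit and the module path github.com/mochi-co/mqtt:

package main

import (
	"net"

	"github.com/mochi-co/mqtt"
	"github.com/mochi-co/mqtt/circ"
)

func main() {
	// net.Pipe stands in for a real client connection; the far end is
	// closed when we are done with it.
	conn, far := net.Pipe()
	defer far.Close()

	// Buffer and block sizes are illustrative.
	p := mqtt.NewProcessor(conn, circ.NewReader(1024, 128), circ.NewWriter(1024, 128))

	// Launch the ReadFrom/WriteTo goroutines.
	p.Start()

	// Close the writer, the connection and the reader, then wait for the
	// goroutines to end.
	p.Stop()
}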

View File

@@ -7,20 +7,20 @@ import (
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/circ"
"github.com/mochi-co/mqtt/packets"
"github.com/mochi-co/mqtt/processors/circ"
)
func TestNewProcessor(t *testing.T) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
require.NotNil(t, p.R)
}
func BenchmarkNewProcessor(b *testing.B) {
conn := new(MockNetConn)
r := circ.NewReader(16)
w := circ.NewWriter(16)
r := circ.NewReader(16, 4)
w := circ.NewWriter(16, 4)
for n := 0; n < b.N; n++ {
NewProcessor(conn, r, w)
@@ -29,7 +29,7 @@ func BenchmarkNewProcessor(b *testing.B) {
func TestProcessorRefreshDeadline(t *testing.T) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
dl := p.Conn.(*MockNetConn).Deadline
p.RefreshDeadline(10)
@@ -39,21 +39,17 @@ func TestProcessorRefreshDeadline(t *testing.T) {
func BenchmarkProcessorRefreshDeadline(b *testing.B) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
for n := 0; n < b.N; n++ {
p.RefreshDeadline(10)
}
}
func TestProcessorStart(t *testing.T) {
}
func TestProcessorReadFixedHeader(t *testing.T) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
p := NewProcessor(conn, circ.NewReader(16, 4), circ.NewWriter(16, 4))
// Test null data.
fh := new(packets.FixedHeader)
@@ -82,7 +78,7 @@ func TestProcessorRead(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
p.R.Set([]byte{
byte(packets.Publish << 4), 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
@@ -108,7 +104,7 @@ func TestProcessorRead(t *testing.T) {
func TestProcessorReadFail(t *testing.T) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
p.R.Set([]byte{
byte(packets.Publish << 4), 3, // Fixed header
0, 5, // Topic Name - LSB+MSB
@@ -143,7 +139,7 @@ func TestProcessorReadPacketNoOverwrite(t *testing.T) {
'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload
}
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
p.R.Set(pk1, 0, len(pk1))
p.R.SetPos(0, int64(len(pk1)))
var fh packets.FixedHeader
@@ -168,7 +164,7 @@ func TestProcessorReadPacketNoOverwrite(t *testing.T) {
func TestProcessorReadPacketNil(t *testing.T) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
var fh packets.FixedHeader
// Check for un-specified packet.
@@ -188,7 +184,7 @@ func TestProcessorReadPacketNil(t *testing.T) {
func TestProcessorReadPacketReadOverflow(t *testing.T) {
conn := new(MockNetConn)
p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
p := NewProcessor(conn, circ.NewReader(32, 4), circ.NewWriter(32, 4))
var fh packets.FixedHeader
// Check for un-specified packet.