Move Parser, cleanup to basic

This commit is contained in:
Mochi
2019-10-09 08:30:57 +01:00
parent d5e2312aa7
commit c3af69f714
40 changed files with 732 additions and 709 deletions

View File

@@ -78,7 +78,7 @@ type client struct {
sync.RWMutex
// p is a packets parser which reads incoming packets.
p *packets.Parser
p *Parser
// ac is a pointer to an auth controller inherited from the listener.
ac auth.Controller
@@ -123,7 +123,7 @@ type client struct {
}
// newClient creates a new instance of client.
func newClient(p *packets.Parser, pk *packets.ConnectPacket, ac auth.Controller) *client {
func newClient(p *Parser, pk *packets.ConnectPacket, ac auth.Controller) *client {
cl := &client{
p: p,
ac: ac,

View File

@@ -117,7 +117,7 @@ func BenchmarkClientsGetByListener(b *testing.B) {
func TestNewClient(t *testing.T) {
r, _ := net.Pipe()
p := packets.NewParser(r, newBufioReader(r), newBufioWriter(r))
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
r.Close()
pk := &packets.ConnectPacket{
FixedHeader: packets.FixedHeader{
@@ -154,7 +154,7 @@ func TestNewClient(t *testing.T) {
func TestNewClientLWT(t *testing.T) {
r, _ := net.Pipe()
p := packets.NewParser(r, newBufioReader(r), newBufioWriter(r))
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
r.Close()
pk := &packets.ConnectPacket{
FixedHeader: packets.FixedHeader{
@@ -182,7 +182,7 @@ func TestNewClientLWT(t *testing.T) {
func BenchmarkNewClient(b *testing.B) {
r, _ := net.Pipe()
p := packets.NewParser(r, newBufioReader(r), newBufioWriter(r))
p := NewParser(r, newBufioReader(r), newBufioWriter(r))
r.Close()
pk := new(packets.ConnectPacket)
@@ -245,7 +245,7 @@ func BenchmarkClientForgetSubscription(b *testing.B) {
func TestClientClose(t *testing.T) {
r, w := net.Pipe()
p := packets.NewParser(r, newBufioReader(r), newBufioWriter(w))
p := NewParser(r, newBufioReader(r), newBufioWriter(w))
pk := &packets.ConnectPacket{
ClientIdentifier: "zen3",
}

View File

@@ -105,12 +105,15 @@ func (l *TCP) Serve(establish EstablishFunc) {
}
// Establish connection in a new goroutine.
go func(c net.Conn) {
err := establish(l.id, c, l.config.Auth)
if err != nil {
return
}
}(conn)
go establish(l.id, conn, l.config.Auth)
/*
go func(c net.Conn) {
err := establish(l.id, c, l.config.Auth)
if err != nil {
return
}
}(conn)
*/
}
}
})

143
mqtt.go
View File

@@ -4,10 +4,7 @@ import (
"bufio"
"bytes"
"errors"
"log"
"net"
"runtime"
"sync"
"time"
//"github.com/davecgh/go-spew/spew"
@@ -54,15 +51,6 @@ type Server struct {
// topics is an index of topic subscriptions and retained messages.
topics topics.Indexer
// inbound is a small worker pool which processes incoming packets.
inbound chan transitMsg
// outbound is a small worker pool which processes incoming packets.
outbound chan transitMsg
//inboundPool is a waitgroup for the inbound workers.
inboundPool sync.WaitGroup
}
// transitMsg contains data to be sent to the inbound channel.
@@ -81,72 +69,9 @@ func New() *Server {
listeners: listeners.NewListeners(),
clients: newClients(),
topics: trie.New(),
inbound: make(chan transitMsg),
outbound: make(chan transitMsg),
}
}
func (s *Server) StartProcessing() {
var workers = runtime.NumCPU()
s.inboundPool = sync.WaitGroup{}
s.inboundPool.Add(workers)
for i := 0; i < workers; i++ {
log.Println("spawning worker", i)
go func(wid int) {
defer s.inboundPool.Done()
for {
select {
case v, ok := <-s.inbound:
if !ok {
log.Println("worker inbound closed", wid)
return
}
s.processPacket(v.client, v.packet)
case v, ok := <-s.outbound:
if !ok {
log.Println("worker outbound closed", wid)
return
}
err := s.writeClient(v.client, v.packet)
if err != nil {
log.Println("outbound closing client", v.client.id)
s.closeClient(v.client, true)
}
}
}
}(i)
}
s.inboundPool.Add(workers)
for i := 0; i < workers; i++ {
log.Println("spawning worker", i)
go func(wid int) {
defer s.inboundPool.Done()
for {
select {
case v, ok := <-s.outbound:
if !ok {
log.Println("worker outbound closed", wid)
return
}
err := s.writeClient(v.client, v.packet)
if err != nil {
log.Println("outbound closing client", v.client.id)
s.closeClient(v.client, true)
}
}
}
}(i)
}
log.Println("spawned all workers")
return
}
// AddListener adds a new network listener to the server.
func (s *Server) AddListener(listener listeners.Listener, config *listeners.Config) error {
if _, ok := s.listeners.Get(listener.ID()); ok {
@@ -165,7 +90,6 @@ func (s *Server) AddListener(listener listeners.Listener, config *listeners.Conf
// Serve begins the event loops for establishing client connections on all
// attached listeners.
func (s *Server) Serve() error {
s.StartProcessing()
s.listeners.ServeAll(s.EstablishConnection)
return nil
@@ -176,7 +100,7 @@ func (s *Server) EstablishConnection(lid string, c net.Conn, ac auth.Controller)
// Create a new packets parser which will parse all packets for this client,
// using buffered writers and readers.
p := packets.NewParser(
p := NewParser(
c,
bufio.NewReaderSize(c, rwBufSize),
bufio.NewWriterSize(c, rwBufSize),
@@ -328,9 +252,7 @@ DONE:
}
// Process inbound packet.
s.inbound <- transitMsg{client: cl, packet: pk}
//go s.processPacket(cl, pk)
s.processPacket(cl, pk)
}
}
@@ -485,14 +407,11 @@ func (s *Server) processPublish(cl *client, pk *packets.PublishPacket) error {
}
// Write the publish packet out to the receiving client.
s.outbound <- transitMsg{client: client, packet: out}
/*
err := s.writeClient(client, out)
if err != nil {
s.closeClient(client, true)
return err
}
*/
err := s.writeClient(client, out)
if err != nil {
s.closeClient(client, true)
return err
}
}
}
@@ -572,42 +491,27 @@ func (s *Server) processSubscribe(cl *client, pk *packets.SubscribePacket) error
}
}
s.outbound <- transitMsg{
client: cl,
packet: &packets.SubackPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Suback,
},
PacketID: pk.PacketID,
ReturnCodes: retCodes,
err := s.writeClient(cl, &packets.SubackPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Suback,
},
PacketID: pk.PacketID,
ReturnCodes: retCodes,
})
if err != nil {
s.closeClient(cl, true)
return err
}
/*
err := s.writeClient(cl, &packets.SubackPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Suback,
},
PacketID: pk.PacketID,
ReturnCodes: retCodes,
})
if err != nil {
s.closeClient(cl, true)
return err
}
*/
// Publish out any retained messages matching the subscription filter.
for i := 0; i < len(pk.Topics); i++ {
messages := s.topics.Messages(pk.Topics[i])
for _, pkv := range messages {
s.outbound <- transitMsg{client: cl, packet: pkv}
/*
err := s.writeClient(cl, pkv)
if err != nil {
s.closeClient(cl, true)
return err
}
*/
err := s.writeClient(cl, pkv)
if err != nil {
s.closeClient(cl, true)
return err
}
}
}
@@ -677,11 +581,6 @@ func (s *Server) Close() error {
// Close all listeners.
s.listeners.CloseAll(s.closeListenerClients)
// Close down waitgroups and pools.
close(s.inbound)
close(s.outbound)
s.inboundPool.Wait()
return nil
}

View File

@@ -1,10 +1,8 @@
package mqtt
import (
"bufio"
"errors"
"io/ioutil"
//"log"
"net"
"testing"
"time"
@@ -17,14 +15,6 @@ import (
"github.com/mochi-co/mqtt/topics"
)
func newBufioReader(c net.Conn) *bufio.Reader {
return bufio.NewReaderSize(c, 512)
}
func newBufioWriter(c net.Conn) *bufio.Writer {
return bufio.NewWriterSize(c, 512)
}
type quietWriter struct {
b []byte
f [][]byte
@@ -57,7 +47,7 @@ func setupClient(id string) (s *Server, r net.Conn, w net.Conn, cl *client) {
s = New()
r, w = net.Pipe()
cl = newClient(
packets.NewParser(r, newBufioReader(r), newBufioWriter(w)),
NewParser(r, newBufioReader(r), newBufioWriter(w)),
&packets.ConnectPacket{
ClientIdentifier: id,
},
@@ -77,8 +67,6 @@ func TestNew(t *testing.T) {
require.NotNil(t, s)
require.NotNil(t, s.listeners)
require.NotNil(t, s.clients)
require.NotNil(t, s.inbound)
require.NotNil(t, s.outbound)
// log.Println(s)
}
@@ -88,11 +76,6 @@ func BenchmarkNew(b *testing.B) {
}
}
func TestProcessor(t *testing.T) {
s := New()
}
func TestServerAddListener(t *testing.T) {
s := New()
require.NotNil(t, s)
@@ -198,7 +181,7 @@ func TestServerEstablishConnectionOKCleanSession(t *testing.T) {
r2, w2 := net.Pipe()
s := New()
s.clients.internal["zen"] = newClient(
packets.NewParser(r2, newBufioReader(r2), newBufioWriter(w2)),
NewParser(r2, newBufioReader(r2), newBufioWriter(w2)),
&packets.ConnectPacket{ClientIdentifier: "zen"},
new(auth.Allow),
)
@@ -249,7 +232,7 @@ func TestServerEstablishConnectionOKInheritSession(t *testing.T) {
r2, w2 := net.Pipe()
s := New()
s.clients.internal["zen"] = newClient(
packets.NewParser(r2, newBufioReader(r2), newBufioWriter(w2)),
NewParser(r2, newBufioReader(r2), newBufioWriter(w2)),
&packets.ConnectPacket{ClientIdentifier: "zen"},
new(auth.Allow),
)

View File

@@ -15,7 +15,7 @@ type ConnackPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *ConnackPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.WriteByte(encodeBool(pk.SessionPresent))
buf.WriteByte(pk.ReturnCode)

View File

@@ -57,7 +57,7 @@ func TestConnackDecode(t *testing.T) {
require.Equal(t, uint8(2), Connack, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Connack).(*ConnackPacket)
pk := &ConnackPacket{FixedHeader: FixedHeader{Type: Connack}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
@@ -73,8 +73,8 @@ func TestConnackDecode(t *testing.T) {
}
func BenchmarkConnackDecode(b *testing.B) {
pk := newPacket(Connack).(*ConnackPacket)
pk.FixedHeader.decode(expectedPackets[Connack][0].rawBytes[0])
pk := &ConnackPacket{FixedHeader: FixedHeader{Type: Connack}}
pk.FixedHeader.Decode(expectedPackets[Connack][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Connack][0].rawBytes[2:])
@@ -82,8 +82,8 @@ func BenchmarkConnackDecode(b *testing.B) {
}
func TestConnackValidate(t *testing.T) {
pk := newPacket(Connack).(*ConnackPacket)
pk.FixedHeader.decode(expectedPackets[Connack][0].rawBytes[0])
pk := &ConnackPacket{FixedHeader: FixedHeader{Type: Connack}}
pk.FixedHeader.Decode(expectedPackets[Connack][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -92,8 +92,8 @@ func TestConnackValidate(t *testing.T) {
}
func BenchmarkConnackValidate(b *testing.B) {
pk := newPacket(Connack).(*ConnackPacket)
pk.FixedHeader.decode(expectedPackets[Connack][0].rawBytes[0])
pk := &ConnackPacket{FixedHeader: FixedHeader{Type: Connack}}
pk.FixedHeader.Decode(expectedPackets[Connack][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -58,7 +58,7 @@ func (pk *ConnectPacket) Encode(buf *bytes.Buffer) error {
len(willTopic) + len(willFlag) +
len(usernameFlag) + len(passwordFlag)
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
// Eschew magic for readability.
buf.Write(protoName)

View File

@@ -78,7 +78,7 @@ func TestConnectDecode(t *testing.T) {
require.Equal(t, uint8(1), Connect, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficent bytes in packet [i:%d] %s", i, wanted.desc)
pk := newPacket(Connect).(*ConnectPacket)
pk := &ConnectPacket{FixedHeader: FixedHeader{Type: Connect}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
@@ -113,8 +113,8 @@ func TestConnectDecode(t *testing.T) {
}
func BenchmarkConnectDecode(b *testing.B) {
pk := newPacket(Connect).(*ConnectPacket)
pk.FixedHeader.decode(expectedPackets[Connect][0].rawBytes[0])
pk := &ConnectPacket{FixedHeader: FixedHeader{Type: Connect}}
pk.FixedHeader.Decode(expectedPackets[Connect][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Connect][0].rawBytes[2:])
@@ -133,8 +133,8 @@ func TestConnectValidate(t *testing.T) {
}
func BenchmarkConnectValidate(b *testing.B) {
pk := newPacket(Connect).(*ConnectPacket)
pk.FixedHeader.decode(expectedPackets[Connect][0].rawBytes[0])
pk := &ConnectPacket{FixedHeader: FixedHeader{Type: Connect}}
pk.FixedHeader.Decode(expectedPackets[Connect][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -11,7 +11,7 @@ type DisconnectPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *DisconnectPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
return nil
}

View File

@@ -40,7 +40,7 @@ func BenchmarkDisconnectEncode(b *testing.B) {
}
func TestDisconnectDecode(t *testing.T) {
pk := newPacket(Disconnect).(*DisconnectPacket)
pk := &DisconnectPacket{FixedHeader: FixedHeader{Type: Disconnect}}
var b = []byte{}
err := pk.Decode(b)
@@ -49,8 +49,8 @@ func TestDisconnectDecode(t *testing.T) {
}
func BenchmarkDisconnectDecode(b *testing.B) {
pk := newPacket(Disconnect).(*DisconnectPacket)
pk.FixedHeader.decode(expectedPackets[Disconnect][0].rawBytes[0])
pk := &DisconnectPacket{FixedHeader: FixedHeader{Type: Disconnect}}
pk.FixedHeader.Decode(expectedPackets[Disconnect][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Disconnect][0].rawBytes[2:])
@@ -58,8 +58,8 @@ func BenchmarkDisconnectDecode(b *testing.B) {
}
func TestDisconnectValidate(t *testing.T) {
pk := newPacket(Disconnect).(*DisconnectPacket)
pk.FixedHeader.decode(expectedPackets[Disconnect][0].rawBytes[0])
pk := &DisconnectPacket{FixedHeader: FixedHeader{Type: Disconnect}}
pk.FixedHeader.Decode(expectedPackets[Disconnect][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -68,8 +68,8 @@ func TestDisconnectValidate(t *testing.T) {
}
func BenchmarkDisconnectValidate(b *testing.B) {
pk := newPacket(Disconnect).(*DisconnectPacket)
pk.FixedHeader.decode(expectedPackets[Disconnect][0].rawBytes[0])
pk := &DisconnectPacket{FixedHeader: FixedHeader{Type: Disconnect}}
pk.FixedHeader.Decode(expectedPackets[Disconnect][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -24,13 +24,13 @@ type FixedHeader struct {
}
// encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) encode(buf *bytes.Buffer) {
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, fh.Remaining)
}
// decode extracts the specification bits from the header byte.
func (fh *FixedHeader) decode(headerByte byte) error {
func (fh *FixedHeader) Decode(headerByte byte) error {
// Get the message type from the first 4 bytes.
fh.Type = headerByte >> 4

View File

@@ -149,7 +149,7 @@ var fixedHeaderExpected = []fixedHeaderTable{
func TestFixedHeaderEncode(t *testing.T) {
for i, wanted := range fixedHeaderExpected {
buf := new(bytes.Buffer)
wanted.header.encode(buf)
wanted.header.Encode(buf)
if wanted.flagError == false {
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes)
require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes)
@@ -160,14 +160,14 @@ func TestFixedHeaderEncode(t *testing.T) {
func BenchmarkFixedHeaderEncode(b *testing.B) {
buf := new(bytes.Buffer)
for n := 0; n < b.N; n++ {
fixedHeaderExpected[0].header.encode(buf)
fixedHeaderExpected[0].header.Encode(buf)
}
}
func TestFixedHeaderDecode(t *testing.T) {
for i, wanted := range fixedHeaderExpected {
fh := new(FixedHeader)
err := fh.decode(wanted.rawBytes[0])
err := fh.Decode(wanted.rawBytes[0])
if wanted.flagError {
require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
} else {
@@ -183,7 +183,7 @@ func TestFixedHeaderDecode(t *testing.T) {
func BenchmarkFixedHeaderDecode(b *testing.B) {
fh := new(FixedHeader)
for n := 0; n < b.N; n++ {
err := fh.decode(fixedHeaderExpected[0].rawBytes[0])
err := fh.Decode(fixedHeaderExpected[0].rawBytes[0])
if err != nil {
panic(err)
}

View File

@@ -1,5 +1,9 @@
package packets
import (
"bytes"
)
// All of the valid packet types and their packet identifier.
const (
Reserved byte = iota
@@ -39,41 +43,15 @@ var Names = map[byte]string{
14: "DISCONNECT",
}
// newPacket returns a packet of a specified packetType.
// this is a convenience package for testing and shouldn't be used for production
// code.
func newPacket(packetType byte) Packet {
// Packet is the base interface that all MQTT packets must implement.
type Packet interface {
switch packetType {
case Connect:
return &ConnectPacket{FixedHeader: FixedHeader{Type: Connect}}
case Connack:
return &ConnackPacket{FixedHeader: FixedHeader{Type: Connack}}
case Publish:
return &PublishPacket{FixedHeader: FixedHeader{Type: Publish}}
case Puback:
return &PubackPacket{FixedHeader: FixedHeader{Type: Puback}}
case Pubrec:
return &PubrecPacket{FixedHeader: FixedHeader{Type: Pubrec}}
case Pubrel:
return &PubrelPacket{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}}
case Pubcomp:
return &PubcompPacket{FixedHeader: FixedHeader{Type: Pubcomp}}
case Subscribe:
return &SubscribePacket{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}}
case Suback:
return &SubackPacket{FixedHeader: FixedHeader{Type: Suback}}
case Unsubscribe:
return &UnsubscribePacket{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}}
case Unsuback:
return &UnsubackPacket{FixedHeader: FixedHeader{Type: Unsuback}}
case Pingreq:
return &PingreqPacket{FixedHeader: FixedHeader{Type: Pingreq}}
case Pingresp:
return &PingrespPacket{FixedHeader: FixedHeader{Type: Pingresp}}
case Disconnect:
return &DisconnectPacket{FixedHeader: FixedHeader{Type: Disconnect}}
}
return nil
// Encode encodes a packet into a byte buffer.
Encode(*bytes.Buffer) error
// Decode decodes a byte array into a packet struct.
Decode([]byte) error
// Validate the packet. Returns a error code and error if not valid.
Validate() (byte, error)
}

View File

@@ -1,5 +1,6 @@
package packets
/*
import (
"testing"
@@ -12,3 +13,4 @@ func TestNewPacket(t *testing.T) {
pk := newPacket(99)
require.Equal(t, nil, pk, "Returned packet should be nil")
}
*/

View File

@@ -1,342 +0,0 @@
package packets
import (
"bufio"
"bytes"
"io"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func newBufioReader(c io.Reader) *bufio.Reader {
return bufio.NewReaderSize(c, 512)
}
func newBufioWriter(c io.Writer) *bufio.Writer {
return bufio.NewWriterSize(c, 512)
}
func TestNewParser(t *testing.T) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
require.NotNil(t, p.R)
}
func BenchmarkNewParser(b *testing.B) {
conn := new(MockNetConn)
r, w := new(bufio.Reader), new(bufio.Writer)
for n := 0; n < b.N; n++ {
NewParser(conn, r, w)
}
}
func TestRefreshDeadline(t *testing.T) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
dl := p.Conn.(*MockNetConn).Deadline
p.RefreshDeadline(10)
require.NotEqual(t, dl, p.Conn.(*MockNetConn).Deadline)
}
func BenchmarkRefreshDeadline(b *testing.B) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
for n := 0; n < b.N; n++ {
p.RefreshDeadline(10)
}
}
func TestReadFixedHeader(t *testing.T) {
conn := new(MockNetConn)
// Test null data.
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
fh := new(FixedHeader)
err := p.ReadFixedHeader(fh)
require.Error(t, err)
// Test insufficient peeking.
fh = new(FixedHeader)
p.R = bufio.NewReader(bytes.NewReader([]byte{Connect << 4}))
err = p.ReadFixedHeader(fh)
require.Error(t, err)
// Test expected bytes.
for i, wanted := range fixedHeaderExpected {
fh := new(FixedHeader)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
b := wanted.rawBytes
p.R = bufio.NewReader(bytes.NewReader(b))
err := p.ReadFixedHeader(fh)
if wanted.packetError || wanted.flagError {
require.Error(t, err, "Expected error [i:%d] %v", i, wanted.rawBytes)
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Type, p.FixedHeader.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Dup, p.FixedHeader.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Qos, p.FixedHeader.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Retain, p.FixedHeader.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
}
}
}
func BenchmarkReadFixedHeader(b *testing.B) {
conn := new(MockNetConn)
fh := new(FixedHeader)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
var rn bytes.Reader = *bytes.NewReader(fixedHeaderExpected[0].rawBytes)
var rc bytes.Reader
for n := 0; n < b.N; n++ {
rc = rn
p.R.Reset(&rc)
err := p.ReadFixedHeader(fh)
if err != nil {
panic(err)
}
}
}
func TestRead(t *testing.T) {
conn := new(MockNetConn)
for code, pt := range expectedPackets {
for i, wanted := range pt {
if wanted.primary {
var fh FixedHeader
b := wanted.rawBytes
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader(b))
err := p.ReadFixedHeader(&fh)
if wanted.failFirst != nil {
require.Error(t, err, "Expected error reading fixedheader [i:%d] %s - %s", i, wanted.desc, Names[code])
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %s - %s", i, wanted.desc, Names[code])
}
pko, err := p.Read()
if wanted.expect != nil {
require.Error(t, err, "Expected error reading packet [i:%d] %s - %s", i, wanted.desc, Names[code])
if err != nil {
require.Equal(t, err, wanted.expect, "Mismatched packet error [i:%d] %s - %s", i, wanted.desc, Names[code])
}
} else {
require.NoError(t, err, "Error reading packet [i:%d] %s - %s", i, wanted.desc, Names[code])
require.Equal(t, wanted.packet, pko, "Mismatched packet final [i:%d] %s - %s", i, wanted.desc, Names[code])
}
}
}
}
// Fail decoder
var fh FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader([]byte{
byte(Publish << 4), 3, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/',
}))
err := p.ReadFixedHeader(&fh)
require.NoError(t, err)
_, err = p.Read()
require.Error(t, err)
}
func BenchmarkRead(b *testing.B) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader(expectedPackets[Publish][1].rawBytes))
var fh FixedHeader
err := p.ReadFixedHeader(&fh)
if err != nil {
panic(err)
}
var rn bytes.Reader = *bytes.NewReader(expectedPackets[Publish][1].rawBytes)
var rc bytes.Reader
for n := 0; n < b.N; n++ {
rc = rn
p.R.Reset(&rc)
p.R.Discard(2)
_, err := p.Read()
if err != nil {
panic(err)
}
}
}
// This is a super important test. It checks whether or not subsequent packets
// mutate each other. This happens when you use a single byte buffer for decoding
// multiple packets.
func TestReadPacketNoOverwrite(t *testing.T) {
pk1 := []byte{
byte(Publish << 4), 12, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', // Payload
}
pk2 := []byte{
byte(Publish << 4), 14, // Fixed header
0, 5, // Topic Name - LSB+MSB
'x', '/', 'y', '/', 'z', // Topic Name
'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload
}
r, w := net.Pipe()
p := NewParser(r, newBufioReader(r), newBufioWriter(w))
go func() {
w.Write(pk1)
w.Write(pk2)
w.Close()
}()
var fh FixedHeader
err := p.ReadFixedHeader(&fh)
require.NoError(t, err)
o1, err := p.Read()
require.NoError(t, err)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*PublishPacket).Payload)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, pk1[9:])
require.NoError(t, err)
err = p.ReadFixedHeader(&fh)
require.NoError(t, err)
o2, err := p.Read()
require.NoError(t, err)
require.Equal(t, []byte{'y', 'a', 'h', 'a', 'l', 'l', 'o'}, o2.(*PublishPacket).Payload)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*PublishPacket).Payload, "o1 payload was mutated")
}
func TestReadPacketNil(t *testing.T) {
conn := new(MockNetConn)
var fh FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := newPacket(Pingreq).(*PingreqPacket)
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{0, 0}))
err := p.ReadFixedHeader(&fh)
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
func TestReadPacketReadOverflow(t *testing.T) {
conn := new(MockNetConn)
var fh FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := newPacket(Pingreq).(*PingreqPacket)
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{byte(Connect << 4), 0}))
err := p.ReadFixedHeader(&fh)
p.FixedHeader.Remaining = 999999 // overflow buffer
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
func TestReadPacketReadAllFail(t *testing.T) {
conn := new(MockNetConn)
var fh FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := newPacket(Pingreq).(*PingreqPacket)
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{byte(Connect << 4), 0}))
err := p.ReadFixedHeader(&fh)
p.FixedHeader.Remaining = 1 // overflow buffer
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
// MockNetConn satisfies the net.Conn interface.
type MockNetConn struct {
ID string
Deadline time.Time
}
// Read reads bytes from the net io.reader.
func (m *MockNetConn) Read(b []byte) (n int, err error) {
return 0, nil
}
// Read writes bytes to the net io.writer.
func (m *MockNetConn) Write(b []byte) (n int, err error) {
return 0, nil
}
// Close closes the net.Conn connection.
func (m *MockNetConn) Close() error {
return nil
}
// LocalAddr returns the local address of the request.
func (m *MockNetConn) LocalAddr() net.Addr {
return new(MockNetAddr)
}
// RemoteAddr returns the remove address of the request.
func (m *MockNetConn) RemoteAddr() net.Addr {
return new(MockNetAddr)
}
// SetDeadline sets the request deadline.
func (m *MockNetConn) SetDeadline(t time.Time) error {
m.Deadline = t
return nil
}
// SetReadDeadline sets the read deadline.
func (m *MockNetConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline sets the write deadline.
func (m *MockNetConn) SetWriteDeadline(t time.Time) error {
return nil
}
// MockNetAddr satisfies net.Addr interface.
type MockNetAddr struct{}
// Network returns the network protocol.
func (m *MockNetAddr) Network() string {
return "tcp"
}
// String returns the network address.
func (m *MockNetAddr) String() string {
return "127.0.0.1"
}

View File

@@ -11,7 +11,7 @@ type PingreqPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *PingreqPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
return nil
}

View File

@@ -42,7 +42,7 @@ func BenchmarkPingreqEncode(b *testing.B) {
}
func TestPingreqDecode(t *testing.T) {
pk := newPacket(Pingreq).(*PingreqPacket)
pk := &PingreqPacket{FixedHeader: FixedHeader{Type: Pingreq}}
var b = []byte{}
err := pk.Decode(b)
@@ -51,8 +51,8 @@ func TestPingreqDecode(t *testing.T) {
}
func BenchmarkPingreqDecode(b *testing.B) {
pk := newPacket(Pingreq).(*PingreqPacket)
pk.FixedHeader.decode(expectedPackets[Pingreq][0].rawBytes[0])
pk := &PingreqPacket{FixedHeader: FixedHeader{Type: Pingreq}}
pk.FixedHeader.Decode(expectedPackets[Pingreq][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Pingreq][0].rawBytes[2:])
@@ -60,8 +60,8 @@ func BenchmarkPingreqDecode(b *testing.B) {
}
func TestPingreqValidate(t *testing.T) {
pk := newPacket(Pingreq).(*PingreqPacket)
pk.FixedHeader.decode(expectedPackets[Pingreq][0].rawBytes[0])
pk := &PingreqPacket{FixedHeader: FixedHeader{Type: Pingreq}}
pk.FixedHeader.Decode(expectedPackets[Pingreq][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -70,8 +70,8 @@ func TestPingreqValidate(t *testing.T) {
}
func BenchmarkPingreqValidate(b *testing.B) {
pk := newPacket(Pingreq).(*PingreqPacket)
pk.FixedHeader.decode(expectedPackets[Pingreq][0].rawBytes[0])
pk := &PingreqPacket{FixedHeader: FixedHeader{Type: Pingreq}}
pk.FixedHeader.Decode(expectedPackets[Pingreq][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -11,7 +11,7 @@ type PingrespPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *PingrespPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
return nil
}

View File

@@ -41,7 +41,7 @@ func BenchmarkPingrespEncode(b *testing.B) {
}
func TestPingrespDecode(t *testing.T) {
pk := newPacket(Pingresp).(*PingrespPacket)
pk := &PingrespPacket{FixedHeader: FixedHeader{Type: Pingresp}}
var b = []byte{}
err := pk.Decode(b)
@@ -50,8 +50,8 @@ func TestPingrespDecode(t *testing.T) {
}
func BenchmarkPingrespDecode(b *testing.B) {
pk := newPacket(Pingresp).(*PingrespPacket)
pk.FixedHeader.decode(expectedPackets[Pingresp][0].rawBytes[0])
pk := &PingrespPacket{FixedHeader: FixedHeader{Type: Pingresp}}
pk.FixedHeader.Decode(expectedPackets[Pingresp][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Pingresp][0].rawBytes[2:])
@@ -59,8 +59,8 @@ func BenchmarkPingrespDecode(b *testing.B) {
}
func TestPingrespValidate(t *testing.T) {
pk := newPacket(Pingresp).(*PingrespPacket)
pk.FixedHeader.decode(expectedPackets[Pingresp][0].rawBytes[0])
pk := &PingrespPacket{FixedHeader: FixedHeader{Type: Pingresp}}
pk.FixedHeader.Decode(expectedPackets[Pingresp][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -69,8 +69,8 @@ func TestPingrespValidate(t *testing.T) {
}
func BenchmarkPingrespValidate(b *testing.B) {
pk := newPacket(Pingresp).(*PingrespPacket)
pk.FixedHeader.decode(expectedPackets[Pingresp][0].rawBytes[0])
pk := &PingrespPacket{FixedHeader: FixedHeader{Type: Pingresp}}
pk.FixedHeader.Decode(expectedPackets[Pingresp][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -14,7 +14,7 @@ type PubackPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *PubackPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}

View File

@@ -55,7 +55,7 @@ func TestPubackDecode(t *testing.T) {
require.Equal(t, uint8(4), Puback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Puback).(*PubackPacket)
pk := &PubackPacket{FixedHeader: FixedHeader{Type: Puback}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
@@ -71,8 +71,8 @@ func TestPubackDecode(t *testing.T) {
}
func BenchmarkPubackDecode(b *testing.B) {
pk := newPacket(Puback).(*PubackPacket)
pk.FixedHeader.decode(expectedPackets[Puback][0].rawBytes[0])
pk := &PubackPacket{FixedHeader: FixedHeader{Type: Puback}}
pk.FixedHeader.Decode(expectedPackets[Puback][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Puback][0].rawBytes[2:])
@@ -80,8 +80,8 @@ func BenchmarkPubackDecode(b *testing.B) {
}
func TestPubackValidate(t *testing.T) {
pk := newPacket(Puback).(*PubackPacket)
pk.FixedHeader.decode(expectedPackets[Puback][0].rawBytes[0])
pk := &PubackPacket{FixedHeader: FixedHeader{Type: Puback}}
pk.FixedHeader.Decode(expectedPackets[Puback][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -90,8 +90,8 @@ func TestPubackValidate(t *testing.T) {
}
func BenchmarkPubackValidate(b *testing.B) {
pk := newPacket(Puback).(*PubackPacket)
pk.FixedHeader.decode(expectedPackets[Puback][0].rawBytes[0])
pk := &PubackPacket{FixedHeader: FixedHeader{Type: Puback}}
pk.FixedHeader.Decode(expectedPackets[Puback][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -14,7 +14,7 @@ type PubcompPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *PubcompPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}

View File

@@ -36,7 +36,7 @@ func TestPubcompEncode(t *testing.T) {
}
func BenchmarkPubcompEncode(b *testing.B) {
pk := new(PubcompPacket)
pk := &PubcompPacket{FixedHeader: FixedHeader{Type: Pubcomp}}
copier.Copy(pk, expectedPackets[Pubcomp][0].packet.(*PubcompPacket))
buf := new(bytes.Buffer)
@@ -55,7 +55,7 @@ func TestPubcompDecode(t *testing.T) {
require.Equal(t, uint8(7), Pubcomp, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Pubcomp).(*PubcompPacket)
pk := &PubcompPacket{FixedHeader: FixedHeader{Type: Pubcomp}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
@@ -71,8 +71,8 @@ func TestPubcompDecode(t *testing.T) {
}
func BenchmarkPubcompDecode(b *testing.B) {
pk := newPacket(Pubcomp).(*PubcompPacket)
pk.FixedHeader.decode(expectedPackets[Pubcomp][0].rawBytes[0])
pk := &PubcompPacket{FixedHeader: FixedHeader{Type: Pubcomp}}
pk.FixedHeader.Decode(expectedPackets[Pubcomp][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Pubcomp][0].rawBytes[2:])
@@ -80,8 +80,8 @@ func BenchmarkPubcompDecode(b *testing.B) {
}
func TestPubcompValidate(t *testing.T) {
pk := newPacket(Pubcomp).(*PubcompPacket)
pk.FixedHeader.decode(expectedPackets[Pubcomp][0].rawBytes[0])
pk := &PubcompPacket{FixedHeader: FixedHeader{Type: Pubcomp}}
pk.FixedHeader.Decode(expectedPackets[Pubcomp][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -90,8 +90,8 @@ func TestPubcompValidate(t *testing.T) {
}
func BenchmarkPubcompValidate(b *testing.B) {
pk := newPacket(Pubcomp).(*PubcompPacket)
pk.FixedHeader.decode(expectedPackets[Pubcomp][0].rawBytes[0])
pk := &PubcompPacket{FixedHeader: FixedHeader{Type: Pubcomp}}
pk.FixedHeader.Decode(expectedPackets[Pubcomp][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -31,7 +31,7 @@ func (pk *PublishPacket) Encode(buf *bytes.Buffer) error {
}
pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload)
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(topicName)
buf.Write(packetID)
buf.Write(pk.Payload)

View File

@@ -71,8 +71,8 @@ func TestPublishDecode(t *testing.T) {
require.Equal(t, uint8(3), Publish, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Publish).(*PublishPacket)
pk.FixedHeader.decode(wanted.rawBytes[0])
pk := &PublishPacket{FixedHeader: FixedHeader{Type: Publish}}
pk.FixedHeader.Decode(wanted.rawBytes[0])
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
@@ -91,8 +91,8 @@ func TestPublishDecode(t *testing.T) {
}
func BenchmarkPublishDecode(b *testing.B) {
pk := newPacket(Publish).(*PublishPacket)
pk.FixedHeader.decode(expectedPackets[Publish][1].rawBytes[0])
pk := &PublishPacket{FixedHeader: FixedHeader{Type: Publish}}
pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Publish][1].rawBytes[2:])
@@ -104,7 +104,7 @@ func TestPublishCopy(t *testing.T) {
for i, wanted := range expectedPackets[Publish] {
if wanted.group == "copy" {
pk := newPacket(Publish).(*PublishPacket)
pk := &PublishPacket{FixedHeader: FixedHeader{Type: Publish}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc)
@@ -122,8 +122,8 @@ func TestPublishCopy(t *testing.T) {
}
func BenchmarkPublishCopy(b *testing.B) {
pk := newPacket(Publish).(*PublishPacket)
pk.FixedHeader.decode(expectedPackets[Publish][1].rawBytes[0])
pk := &PublishPacket{FixedHeader: FixedHeader{Type: Publish}}
pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Copy()
@@ -151,8 +151,8 @@ func TestPublishValidate(t *testing.T) {
}
func BenchmarkPublishValidate(b *testing.B) {
pk := newPacket(Publish).(*PublishPacket)
pk.FixedHeader.decode(expectedPackets[Publish][1].rawBytes[0])
pk := &PublishPacket{FixedHeader: FixedHeader{Type: Publish}}
pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0])
for n := 0; n < b.N; n++ {
_, err := pk.Validate()

View File

@@ -14,7 +14,7 @@ type PubrecPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *PubrecPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}

View File

@@ -57,7 +57,7 @@ func TestPubrecDecode(t *testing.T) {
require.Equal(t, uint8(5), Pubrec, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Pubrec).(*PubrecPacket)
pk := &PubrecPacket{FixedHeader: FixedHeader{Type: Pubrec}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
@@ -74,8 +74,8 @@ func TestPubrecDecode(t *testing.T) {
}
func BenchmarkPubrecDecode(b *testing.B) {
pk := newPacket(Pubrec).(*PubrecPacket)
pk.FixedHeader.decode(expectedPackets[Pubrec][0].rawBytes[0])
pk := &PubrecPacket{FixedHeader: FixedHeader{Type: Pubrec}}
pk.FixedHeader.Decode(expectedPackets[Pubrec][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Pubrec][0].rawBytes[2:])
@@ -83,8 +83,8 @@ func BenchmarkPubrecDecode(b *testing.B) {
}
func TestPubrecValidate(t *testing.T) {
pk := newPacket(Pubrec).(*PubrecPacket)
pk.FixedHeader.decode(expectedPackets[Pubrec][0].rawBytes[0])
pk := &PubrecPacket{FixedHeader: FixedHeader{Type: Pubrec}}
pk.FixedHeader.Decode(expectedPackets[Pubrec][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -93,8 +93,8 @@ func TestPubrecValidate(t *testing.T) {
}
func BenchmarkPubrecValidate(b *testing.B) {
pk := newPacket(Pubrec).(*PubrecPacket)
pk.FixedHeader.decode(expectedPackets[Pubrec][0].rawBytes[0])
pk := &PubrecPacket{FixedHeader: FixedHeader{Type: Pubrec}}
pk.FixedHeader.Decode(expectedPackets[Pubrec][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -14,7 +14,7 @@ type PubrelPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *PubrelPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}

View File

@@ -59,7 +59,7 @@ func TestPubrelDecode(t *testing.T) {
require.Equal(t, uint8(6), Pubrel, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Pubrel).(*PubrelPacket)
pk := &PubrelPacket{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
@@ -74,8 +74,8 @@ func TestPubrelDecode(t *testing.T) {
}
func BenchmarkPubrelDecode(b *testing.B) {
pk := newPacket(Pubrel).(*PubrelPacket)
pk.FixedHeader.decode(expectedPackets[Pubrel][0].rawBytes[0])
pk := &PubrelPacket{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}}
pk.FixedHeader.Decode(expectedPackets[Pubrel][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Pubrel][0].rawBytes[2:])
@@ -83,8 +83,8 @@ func BenchmarkPubrelDecode(b *testing.B) {
}
func TestPubrelValidate(t *testing.T) {
pk := newPacket(Pubrel).(*PubrelPacket)
pk.FixedHeader.decode(expectedPackets[Pubrel][0].rawBytes[0])
pk := &PubrelPacket{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}}
pk.FixedHeader.Decode(expectedPackets[Pubrel][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -93,8 +93,8 @@ func TestPubrelValidate(t *testing.T) {
}
func BenchmarkPubrelValidate(b *testing.B) {
pk := newPacket(Pubrel).(*PubrelPacket)
pk.FixedHeader.decode(expectedPackets[Pubrel][0].rawBytes[0])
pk := &PubrelPacket{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}}
pk.FixedHeader.Decode(expectedPackets[Pubrel][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -17,7 +17,7 @@ func (pk *SubackPacket) Encode(buf *bytes.Buffer) error {
packetID := encodeUint16(pk.PacketID)
pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length.
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(packetID) // Encode Packet ID.
buf.Write(pk.ReturnCodes) // Encode granted QOS flags.

View File

@@ -62,7 +62,7 @@ func TestSubackDecode(t *testing.T) {
require.Equal(t, uint8(9), Suback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Suback).(*SubackPacket)
pk := &SubackPacket{FixedHeader: FixedHeader{Type: Suback}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
@@ -78,8 +78,8 @@ func TestSubackDecode(t *testing.T) {
}
func BenchmarkSubackDecode(b *testing.B) {
pk := newPacket(Suback).(*SubackPacket)
pk.FixedHeader.decode(expectedPackets[Suback][0].rawBytes[0])
pk := &SubackPacket{FixedHeader: FixedHeader{Type: Suback}}
pk.FixedHeader.Decode(expectedPackets[Suback][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Suback][0].rawBytes[2:])
@@ -87,8 +87,8 @@ func BenchmarkSubackDecode(b *testing.B) {
}
func TestSubackValidate(t *testing.T) {
pk := newPacket(Suback).(*SubackPacket)
pk.FixedHeader.decode(expectedPackets[Suback][0].rawBytes[0])
pk := &SubackPacket{FixedHeader: FixedHeader{Type: Suback}}
pk.FixedHeader.Decode(expectedPackets[Suback][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -97,8 +97,8 @@ func TestSubackValidate(t *testing.T) {
}
func BenchmarkSubackValidate(b *testing.B) {
pk := newPacket(Suback).(*SubackPacket)
pk.FixedHeader.decode(expectedPackets[Suback][0].rawBytes[0])
pk := &SubackPacket{FixedHeader: FixedHeader{Type: Suback}}
pk.FixedHeader.Decode(expectedPackets[Suback][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -31,7 +31,7 @@ func (pk *SubscribePacket) Encode(buf *bytes.Buffer) error {
}
pk.FixedHeader.Remaining = len(packetID) + topicsLen
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(packetID)
// Add all provided topic names and associated QOS flags.

View File

@@ -65,7 +65,7 @@ func TestSubscribeDecode(t *testing.T) {
require.Equal(t, uint8(8), Subscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Subscribe).(*SubscribePacket)
pk := &SubscribePacket{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
@@ -82,8 +82,8 @@ func TestSubscribeDecode(t *testing.T) {
}
func BenchmarkSubscribeDecode(b *testing.B) {
pk := newPacket(Subscribe).(*SubscribePacket)
pk.FixedHeader.decode(expectedPackets[Subscribe][0].rawBytes[0])
pk := &SubscribePacket{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}}
pk.FixedHeader.Decode(expectedPackets[Subscribe][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Subscribe][0].rawBytes[2:])
@@ -91,13 +91,6 @@ func BenchmarkSubscribeDecode(b *testing.B) {
}
func TestSubscribeValidate(t *testing.T) {
/*pk := newPacket(Subscribe).(*SubscribePacket)
pk.FixedHeader.decode(expectedPackets[Subscribe][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
require.Equal(t, Accepted, b)
*/
require.Contains(t, expectedPackets, Subscribe)
for i, wanted := range expectedPackets[Subscribe] {
if wanted.group == "validate" || i == 0 {
@@ -118,8 +111,8 @@ func TestSubscribeValidate(t *testing.T) {
}
func BenchmarkSubscribeValidate(b *testing.B) {
pk := newPacket(Subscribe).(*SubscribePacket)
pk.FixedHeader.decode(expectedPackets[Subscribe][0].rawBytes[0])
pk := &SubscribePacket{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}}
pk.FixedHeader.Decode(expectedPackets[Subscribe][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -14,7 +14,7 @@ type UnsubackPacket struct {
// Encode encodes and writes the packet data values to the buffer.
func (pk *UnsubackPacket) Encode(buf *bytes.Buffer) error {
pk.FixedHeader.Remaining = 2
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(encodeUint16(pk.PacketID))
return nil
}

View File

@@ -59,7 +59,7 @@ func TestUnsubackDecode(t *testing.T) {
require.Equal(t, uint8(11), Unsuback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Unsuback).(*UnsubackPacket)
pk := &UnsubackPacket{FixedHeader: FixedHeader{Type: Unsuback}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
@@ -74,8 +74,8 @@ func TestUnsubackDecode(t *testing.T) {
}
func BenchmarkUnsubackDecode(b *testing.B) {
pk := newPacket(Unsuback).(*UnsubackPacket)
pk.FixedHeader.decode(expectedPackets[Unsuback][0].rawBytes[0])
pk := &UnsubackPacket{FixedHeader: FixedHeader{Type: Unsuback}}
pk.FixedHeader.Decode(expectedPackets[Unsuback][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Unsuback][0].rawBytes[2:])
@@ -83,8 +83,8 @@ func BenchmarkUnsubackDecode(b *testing.B) {
}
func TestUnsubackValidate(t *testing.T) {
pk := newPacket(Unsuback).(*UnsubackPacket)
pk.FixedHeader.decode(expectedPackets[Unsuback][0].rawBytes[0])
pk := &UnsubackPacket{FixedHeader: FixedHeader{Type: Unsuback}}
pk.FixedHeader.Decode(expectedPackets[Unsuback][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
@@ -93,8 +93,8 @@ func TestUnsubackValidate(t *testing.T) {
}
func BenchmarkUnsubackValidate(b *testing.B) {
pk := newPacket(Unsuback).(*UnsubackPacket)
pk.FixedHeader.decode(expectedPackets[Unsuback][0].rawBytes[0])
pk := &UnsubackPacket{FixedHeader: FixedHeader{Type: Unsuback}}
pk.FixedHeader.Decode(expectedPackets[Unsuback][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -30,7 +30,7 @@ func (pk *UnsubscribePacket) Encode(buf *bytes.Buffer) error {
}
pk.FixedHeader.Remaining = len(packetID) + topicsLen
pk.FixedHeader.encode(buf)
pk.FixedHeader.Encode(buf)
buf.Write(packetID)
// Add all provided topic names.

View File

@@ -66,7 +66,7 @@ func TestUnsubscribeDecode(t *testing.T) {
require.Equal(t, uint8(10), Unsubscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc)
pk := newPacket(Unsubscribe).(*UnsubscribePacket)
pk := &UnsubscribePacket{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}}
err := pk.Decode(wanted.rawBytes[2:]) // Unpack skips fixedheader.
if wanted.failFirst != nil {
require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc)
@@ -82,8 +82,8 @@ func TestUnsubscribeDecode(t *testing.T) {
}
func BenchmarkUnsubscribeDecode(b *testing.B) {
pk := newPacket(Unsubscribe).(*UnsubscribePacket)
pk.FixedHeader.decode(expectedPackets[Unsubscribe][0].rawBytes[0])
pk := &UnsubscribePacket{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}}
pk.FixedHeader.Decode(expectedPackets[Unsubscribe][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Decode(expectedPackets[Unsubscribe][0].rawBytes[2:])
@@ -91,13 +91,6 @@ func BenchmarkUnsubscribeDecode(b *testing.B) {
}
func TestUnsubscribeValidate(t *testing.T) {
/*pk := newPacket(Unsubscribe).(*UnsubscribePacket)
pk.FixedHeader.decode(expectedPackets[Unsubscribe][0].rawBytes[0])
b, err := pk.Validate()
require.NoError(t, err)
require.Equal(t, Accepted, b)
*/
require.Contains(t, expectedPackets, Unsubscribe)
for i, wanted := range expectedPackets[Unsubscribe] {
if wanted.group == "validate" || i == 0 {
@@ -118,8 +111,8 @@ func TestUnsubscribeValidate(t *testing.T) {
}
func BenchmarkUnsubscribeValidate(b *testing.B) {
pk := newPacket(Unsubscribe).(*UnsubscribePacket)
pk.FixedHeader.decode(expectedPackets[Unsubscribe][0].rawBytes[0])
pk := &UnsubscribePacket{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}}
pk.FixedHeader.Decode(expectedPackets[Unsubscribe][0].rawBytes[0])
for n := 0; n < b.N; n++ {
pk.Validate()

View File

@@ -1,29 +1,17 @@
package packets
package mqtt
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
"github.com/mochi-co/mqtt/packets"
)
// Packet is the base interface that all MQTT packets must implement.
type Packet interface {
// Encode encodes a packet into a byte buffer.
Encode(*bytes.Buffer) error
// Decode decodes a byte array into a packet struct.
Decode([]byte) error
// Validate the packet. Returns a error code and error if not valid.
Validate() (byte, error)
}
// BufWriter is an interface for satisfying a bufio.Writer. This is mainly
// in place to allow testing.
type BufWriter interface {
@@ -50,7 +38,7 @@ type Parser struct {
W BufWriter
// FixedHeader is the fixed header from the last seen packet.
FixedHeader FixedHeader
FixedHeader packets.FixedHeader
}
// NewParser returns an instance of Parser for a connection.
@@ -71,7 +59,7 @@ func (p *Parser) RefreshDeadline(keepalive uint16) {
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (p *Parser) ReadFixedHeader(fh *FixedHeader) error {
func (p *Parser) ReadFixedHeader(fh *packets.FixedHeader) error {
// Peek the maximum message type and flags, and length.
peeked, err := p.R.Peek(1)
@@ -80,12 +68,12 @@ func (p *Parser) ReadFixedHeader(fh *FixedHeader) error {
}
// Unpack message type and flags from byte 1.
err = fh.decode(peeked[0])
err = fh.Decode(peeked[0])
if err != nil {
// @SPEC [MQTT-2.2.2-2]
// If invalid flags are received, the receiver MUST close the Network Connection.
return ErrInvalidFlags
return packets.ErrInvalidFlags
}
// The remaining length value can be up to 5 bytes. Peek through each byte
@@ -112,7 +100,7 @@ func (p *Parser) ReadFixedHeader(fh *FixedHeader) error {
// If i has reached 4 without a length terminator, throw a protocol violation.
i++
if i == 4 {
return ErrOversizedLengthIndicator
return packets.ErrOversizedLengthIndicator
}
}
@@ -130,37 +118,37 @@ func (p *Parser) ReadFixedHeader(fh *FixedHeader) error {
}
// Read reads the remaining buffer into an MQTT packet.
func (p *Parser) Read() (pk Packet, err error) {
func (p *Parser) Read() (pk packets.Packet, err error) {
switch p.FixedHeader.Type {
case Connect:
pk = &ConnectPacket{FixedHeader: p.FixedHeader}
case Connack:
pk = &ConnackPacket{FixedHeader: p.FixedHeader}
case Publish:
pk = &PublishPacket{FixedHeader: p.FixedHeader}
case Puback:
pk = &PubackPacket{FixedHeader: p.FixedHeader}
case Pubrec:
pk = &PubrecPacket{FixedHeader: p.FixedHeader}
case Pubrel:
pk = &PubrelPacket{FixedHeader: p.FixedHeader}
case Pubcomp:
pk = &PubcompPacket{FixedHeader: p.FixedHeader}
case Subscribe:
pk = &SubscribePacket{FixedHeader: p.FixedHeader}
case Suback:
pk = &SubackPacket{FixedHeader: p.FixedHeader}
case Unsubscribe:
pk = &UnsubscribePacket{FixedHeader: p.FixedHeader}
case Unsuback:
pk = &UnsubackPacket{FixedHeader: p.FixedHeader}
case Pingreq:
pk = &PingreqPacket{FixedHeader: p.FixedHeader}
case Pingresp:
pk = &PingrespPacket{FixedHeader: p.FixedHeader}
case Disconnect:
pk = &DisconnectPacket{FixedHeader: p.FixedHeader}
case packets.Connect:
pk = &packets.ConnectPacket{FixedHeader: p.FixedHeader}
case packets.Connack:
pk = &packets.ConnackPacket{FixedHeader: p.FixedHeader}
case packets.Publish:
pk = &packets.PublishPacket{FixedHeader: p.FixedHeader}
case packets.Puback:
pk = &packets.PubackPacket{FixedHeader: p.FixedHeader}
case packets.Pubrec:
pk = &packets.PubrecPacket{FixedHeader: p.FixedHeader}
case packets.Pubrel:
pk = &packets.PubrelPacket{FixedHeader: p.FixedHeader}
case packets.Pubcomp:
pk = &packets.PubcompPacket{FixedHeader: p.FixedHeader}
case packets.Subscribe:
pk = &packets.SubscribePacket{FixedHeader: p.FixedHeader}
case packets.Suback:
pk = &packets.SubackPacket{FixedHeader: p.FixedHeader}
case packets.Unsubscribe:
pk = &packets.UnsubscribePacket{FixedHeader: p.FixedHeader}
case packets.Unsuback:
pk = &packets.UnsubackPacket{FixedHeader: p.FixedHeader}
case packets.Pingreq:
pk = &packets.PingreqPacket{FixedHeader: p.FixedHeader}
case packets.Pingresp:
pk = &packets.PingrespPacket{FixedHeader: p.FixedHeader}
case packets.Disconnect:
pk = &packets.DisconnectPacket{FixedHeader: p.FixedHeader}
default:
return pk, errors.New("No valid packet available; " + string(p.FixedHeader.Type))
}

526
parser_test.go Normal file
View File

@@ -0,0 +1,526 @@
package mqtt
import (
"bufio"
"bytes"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/mochi-co/mqtt/packets"
)
func newBufioReader(c net.Conn) *bufio.Reader {
return bufio.NewReaderSize(c, 512)
}
func newBufioWriter(c net.Conn) *bufio.Writer {
return bufio.NewWriterSize(c, 512)
}
func TestNewParser(t *testing.T) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
require.NotNil(t, p.R)
}
func BenchmarkNewParser(b *testing.B) {
conn := new(MockNetConn)
r, w := new(bufio.Reader), new(bufio.Writer)
for n := 0; n < b.N; n++ {
NewParser(conn, r, w)
}
}
func TestRefreshDeadline(t *testing.T) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
dl := p.Conn.(*MockNetConn).Deadline
p.RefreshDeadline(10)
require.NotEqual(t, dl, p.Conn.(*MockNetConn).Deadline)
}
func BenchmarkRefreshDeadline(b *testing.B) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
for n := 0; n < b.N; n++ {
p.RefreshDeadline(10)
}
}
/*
type fixedHeaderTable struct {
rawBytes []byte
header packets.FixedHeader
packetError bool
flagError bool
}
var fixedHeaderExpected = []fixedHeaderTable{
{
rawBytes: []byte{packets.Connect << 4, 0x00},
header: packets.FixedHeader{packets.Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int
},
{
rawBytes: []byte{packets.Connack << 4, 0x00},
header: packets.FixedHeader{packets.Connack, false, 0, false, 0},
},
{
rawBytes: []byte{packets.Publish << 4, 0x00},
header: packets.FixedHeader{packets.Publish, false, 0, false, 0},
},
{
rawBytes: []byte{packets.Publish<<4 | 1<<1, 0x00},
header: packets.FixedHeader{packets.Publish, false, 1, false, 0},
},
{
rawBytes: []byte{packets.Publish<<4 | 1<<1 | 1, 0x00},
header: packets.FixedHeader{packets.Publish, false, 1, true, 0},
},
{
rawBytes: []byte{packets.Publish<<4 | 2<<1, 0x00},
header: packets.FixedHeader{packets.Publish, false, 2, false, 0},
},
{
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, false, 2, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Publish, true, 0, false, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Publish, true, 0, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 1, true, 0},
},
{
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Publish, true, 2, true, 0},
},
{
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Puback, false, 0, false, 0},
},
{
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Pubrec, false, 0, false, 0},
},
{
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Pubrel, false, 1, false, 0},
},
{
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Pubcomp, false, 0, false, 0},
},
{
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Subscribe, false, 1, false, 0},
},
{
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Suback, false, 0, false, 0},
},
{
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Unsubscribe, false, 1, false, 0},
},
{
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Unsuback, false, 0, false, 0},
},
{
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Pingreq, false, 0, false, 0},
},
{
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Pingresp, false, 0, false, 0},
},
{
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Disconnect, false, 0, false, 0},
},
// remaining length
{
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Publish, false, 0, false, 10},
},
{
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Publish, false, 0, false, 512},
},
{
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Publish, false, 0, false, 978},
},
{
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Publish, false, 0, false, 20102},
},
{
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Publish, false, 0, false, 333333333},
packetError: true,
},
// Invalid flags for packet
{
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Connect, true, 0, false, 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Connect, false, 1, false, 0},
flagError: true,
},
{
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Connect, false, 0, true, 0},
flagError: true,
},
}
*/
func TestReadFixedHeader(t *testing.T) {
conn := new(MockNetConn)
// Test null data.
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
fh := new(packets.FixedHeader)
err := p.ReadFixedHeader(fh)
require.Error(t, err)
// Test insufficient peeking.
fh = new(packets.FixedHeader)
p.R = bufio.NewReader(bytes.NewReader([]byte{packets.Connect << 4}))
err = p.ReadFixedHeader(fh)
require.Error(t, err)
fh = new(packets.FixedHeader)
p.R = bufio.NewReader(bytes.NewReader([]byte{packets.Connect << 4, 0x00}))
err = p.ReadFixedHeader(fh)
require.NoError(t, err)
/*
// Test expected bytes.
for i, wanted := range fixedHeaderExpected {
fh := new(packets.FixedHeader)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
b := wanted.rawBytes
p.R = bufio.NewReader(bytes.NewReader(b))
err := p.ReadFixedHeader(fh)
if wanted.packetError || wanted.flagError {
require.Error(t, err, "Expected error [i:%d] %v", i, wanted.rawBytes)
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Type, p.FixedHeader.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Dup, p.FixedHeader.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Qos, p.FixedHeader.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes)
require.Equal(t, wanted.header.Retain, p.FixedHeader.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes)
}
}
*/
}
func BenchmarkReadFixedHeader(b *testing.B) {
conn := new(MockNetConn)
fh := new(packets.FixedHeader)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
var rn bytes.Reader = *bytes.NewReader([]byte{packets.Connect << 4, 0x00})
var rc bytes.Reader
for n := 0; n < b.N; n++ {
rc = rn
p.R.Reset(&rc)
err := p.ReadFixedHeader(fh)
if err != nil {
panic(err)
}
}
}
func TestRead(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader([]byte{
byte(packets.Publish << 4), 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
}))
err := p.ReadFixedHeader(&fh)
require.NoError(t, err)
pko, err := p.Read()
require.NoError(t, err)
require.Equal(t, &packets.PublishPacket{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 18,
},
TopicName: "a/b/c",
Payload: []byte("hello mochi"),
}, pko)
/*
for code, pt := range expectedPackets {
for i, wanted := range pt {
if wanted.primary {
var fh packets.FixedHeader
b := wanted.rawBytes
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader(b))
err := p.ReadFixedHeader(&fh)
if wanted.failFirst != nil {
require.Error(t, err, "Expected error reading fixedheader [i:%d] %s - %s", i, wanted.desc, Names[code])
} else {
require.NoError(t, err, "Error reading fixedheader [i:%d] %s - %s", i, wanted.desc, Names[code])
}
pko, err := p.Read()
if wanted.expect != nil {
require.Error(t, err, "Expected error reading packet [i:%d] %s - %s", i, wanted.desc, Names[code])
if err != nil {
require.Equal(t, err, wanted.expect, "Mismatched packet error [i:%d] %s - %s", i, wanted.desc, Names[code])
}
} else {
require.NoError(t, err, "Error reading packet [i:%d] %s - %s", i, wanted.desc, Names[code])
require.Equal(t, wanted.packet, pko, "Mismatched packet final [i:%d] %s - %s", i, wanted.desc, Names[code])
}
}
}
}
*/
}
func TestReadFail(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
p.R = bufio.NewReader(bytes.NewReader([]byte{
byte(packets.Publish << 4), 3, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/',
}))
err := p.ReadFixedHeader(&fh)
require.NoError(t, err)
_, err = p.Read()
require.Error(t, err)
}
func BenchmarkRead(b *testing.B) {
conn := new(MockNetConn)
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
pkb := []byte{
byte(packets.Publish << 4), 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
}
p.R = bufio.NewReader(bytes.NewReader(pkb))
var fh packets.FixedHeader
err := p.ReadFixedHeader(&fh)
if err != nil {
panic(err)
}
var rn bytes.Reader = *bytes.NewReader(pkb)
var rc bytes.Reader
for n := 0; n < b.N; n++ {
rc = rn
p.R.Reset(&rc)
p.R.Discard(2)
_, err := p.Read()
if err != nil {
panic(err)
}
}
}
// This is a super important test. It checks whether or not subsequent packets
// mutate each other. This happens when you use a single byte buffer for decoding
// multiple packets.
func TestReadPacketNoOverwrite(t *testing.T) {
pk1 := []byte{
byte(packets.Publish << 4), 12, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', // Payload
}
pk2 := []byte{
byte(packets.Publish << 4), 14, // Fixed header
0, 5, // Topic Name - LSB+MSB
'x', '/', 'y', '/', 'z', // Topic Name
'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload
}
r, w := net.Pipe()
p := NewParser(r, newBufioReader(r), newBufioWriter(w))
go func() {
w.Write(pk1)
w.Write(pk2)
w.Close()
}()
var fh packets.FixedHeader
err := p.ReadFixedHeader(&fh)
require.NoError(t, err)
o1, err := p.Read()
require.NoError(t, err)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, pk1[9:])
require.NoError(t, err)
err = p.ReadFixedHeader(&fh)
require.NoError(t, err)
o2, err := p.Read()
require.NoError(t, err)
require.Equal(t, []byte{'y', 'a', 'h', 'a', 'l', 'l', 'o'}, o2.(*packets.PublishPacket).Payload)
require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload, "o1 payload was mutated")
}
func TestReadPacketNil(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{0, 0}))
err := p.ReadFixedHeader(&fh)
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
func TestReadPacketReadOverflow(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{byte(packets.Connect << 4), 0}))
err := p.ReadFixedHeader(&fh)
p.FixedHeader.Remaining = 999999 // overflow buffer
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
func TestReadPacketReadAllFail(t *testing.T) {
conn := new(MockNetConn)
var fh packets.FixedHeader
p := NewParser(conn, newBufioReader(conn), newBufioWriter(conn))
// Check for un-specified packet.
// Create a ping request packet with a false fixedheader type code.
pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
pk.FixedHeader.Type = 99
p.R = bufio.NewReader(bytes.NewReader([]byte{byte(packets.Connect << 4), 0}))
err := p.ReadFixedHeader(&fh)
p.FixedHeader.Remaining = 1 // overflow buffer
_, err = p.Read()
require.Error(t, err, "Expected error reading packet")
}
// MockNetConn satisfies the net.Conn interface.
type MockNetConn struct {
ID string
Deadline time.Time
}
// Read reads bytes from the net io.reader.
func (m *MockNetConn) Read(b []byte) (n int, err error) {
return 0, nil
}
// Read writes bytes to the net io.writer.
func (m *MockNetConn) Write(b []byte) (n int, err error) {
return 0, nil
}
// Close closes the net.Conn connection.
func (m *MockNetConn) Close() error {
return nil
}
// LocalAddr returns the local address of the request.
func (m *MockNetConn) LocalAddr() net.Addr {
return new(MockNetAddr)
}
// RemoteAddr returns the remove address of the request.
func (m *MockNetConn) RemoteAddr() net.Addr {
return new(MockNetAddr)
}
// SetDeadline sets the request deadline.
func (m *MockNetConn) SetDeadline(t time.Time) error {
m.Deadline = t
return nil
}
// SetReadDeadline sets the read deadline.
func (m *MockNetConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline sets the write deadline.
func (m *MockNetConn) SetWriteDeadline(t time.Time) error {
return nil
}
// MockNetAddr satisfies net.Addr interface.
type MockNetAddr struct{}
// Network returns the network protocol.
func (m *MockNetAddr) Network() string {
return "tcp"
}
// String returns the network address.
func (m *MockNetAddr) String() string {
return "127.0.0.1"
}