diff --git a/packets/connack.go b/packets/connack.go index 8d9c31c..750c1b5 100644 --- a/packets/connack.go +++ b/packets/connack.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // ConnackPacket contains the values of an MQTT CONNACK packet. @@ -15,7 +14,17 @@ type ConnackPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *ConnackPacket) Encode(w io.Writer) error { +func (pk *ConnackPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.encode(buf) + buf.WriteByte(encodeBool(pk.SessionPresent)) + buf.WriteByte(pk.ReturnCode) + + return nil +} + +// Encode encodes and writes the packet data values to the buffer. +/*func (pk *ConnackPacket) Encode(w io.Writer) error { var body bytes.Buffer // Write flags to packet body. @@ -35,6 +44,8 @@ func (pk *ConnackPacket) Encode(w io.Writer) error { } +*/ + // Decode extracts the data values from the packet. func (pk *ConnackPacket) Decode(buf []byte) error { diff --git a/packets/connack_test.go b/packets/connack_test.go index a28d8e3..75fa100 100644 --- a/packets/connack_test.go +++ b/packets/connack_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" ) +/* func TestConnackEncode(t *testing.T) { require.Contains(t, expectedPackets, Connack) for i, wanted := range expectedPackets[Connack] { @@ -38,7 +39,6 @@ func TestConnackEncode(t *testing.T) { require.Equal(t, wanted.packet.(*ConnackPacket).SessionPresent, pk.SessionPresent, "Mismatched session present bool [i:%d] %s", i, wanted.desc) } } - func BenchmarkConnackEncode(b *testing.B) { pk := new(ConnackPacket) copier.Copy(pk, expectedPackets[Connack][0].packet.(*ConnackPacket)) @@ -49,6 +49,47 @@ func BenchmarkConnackEncode(b *testing.B) { } } +*/ + +func TestConnackEncode(t *testing.T) { + require.Contains(t, expectedPackets, Connack) + for i, wanted := range expectedPackets[Connack] { + + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(2), Connack, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(ConnackPacket) + copier.Copy(pk, wanted.packet.(*ConnackPacket)) + + require.Equal(t, Connack, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) + require.Equal(t, Connack, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + require.Equal(t, byte(Connack<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.(*ConnackPacket).ReturnCode, pk.ReturnCode, "Mismatched return code [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.(*ConnackPacket).SessionPresent, pk.SessionPresent, "Mismatched session present bool [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkConnackEncode(b *testing.B) { + pk := new(ConnackPacket) + copier.Copy(pk, expectedPackets[Connack][0].packet.(*ConnackPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestConnackDecode(t *testing.T) { require.Contains(t, expectedPackets, Connack) for i, wanted := range expectedPackets[Connack] { diff --git a/packets/connect.go b/packets/connect.go index f1f35ba..2d21071 100644 --- a/packets/connect.go +++ b/packets/connect.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // ConnectPacket contains the values of an MQTT CONNECT packet. @@ -28,8 +27,7 @@ type ConnectPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *ConnectPacket) Encode(w io.Writer) error { - +func (pk *ConnectPacket) Encode(buf *bytes.Buffer) error { var body bytes.Buffer // Write flags to packet body. @@ -55,15 +53,11 @@ func (pk *ConnectPacket) Encode(w io.Writer) error { body.Write(encodeString(pk.Password)) } - // Set remaining length. pk.FixedHeader.Remaining = body.Len() + pk.FixedHeader.encode(buf) + buf.Write(body.Bytes()) - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) - - return err + return nil } // Decode extracts the data values from the packet. diff --git a/packets/connect_test.go b/packets/connect_test.go index 0b321b2..be24b3d 100644 --- a/packets/connect_test.go +++ b/packets/connect_test.go @@ -22,15 +22,12 @@ func TestConnectEncode(t *testing.T) { require.Equal(t, Connect, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Connect, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) - - encoded := b.Bytes() + buf := new(bytes.Buffer) + pk.Encode(buf) + encoded := buf.Bytes() require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) require.Equal(t, byte(Connect<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) ok, _ := pk.Validate() @@ -60,6 +57,16 @@ func TestConnectEncode(t *testing.T) { } } +func BenchmarkConnectEncode(b *testing.B) { + pk := new(ConnectPacket) + copier.Copy(pk, expectedPackets[Connect][0].packet.(*ConnectPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestConnectDecode(t *testing.T) { require.Contains(t, expectedPackets, Connect) for i, wanted := range expectedPackets[Connect] { diff --git a/packets/disconnect.go b/packets/disconnect.go index dc5a83f..67e40f4 100644 --- a/packets/disconnect.go +++ b/packets/disconnect.go @@ -1,7 +1,7 @@ package packets import ( - "io" + "bytes" ) // DisconnectPacket contains the values of an MQTT DISCONNECT packet. @@ -10,12 +10,9 @@ type DisconnectPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *DisconnectPacket) Encode(w io.Writer) error { - - out := pk.FixedHeader.encode() - _, err := out.WriteTo(w) - - return err +func (pk *DisconnectPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.encode(buf) + return nil } // Decode extracts the data values from the packet. diff --git a/packets/disconnect_test.go b/packets/disconnect_test.go index 0fe6edb..ea46538 100644 --- a/packets/disconnect_test.go +++ b/packets/disconnect_test.go @@ -19,12 +19,23 @@ func TestDisconnectEncode(t *testing.T) { require.Equal(t, Disconnect, pk.Type, "Mismatched Packet Type [i:%d]", i) require.Equal(t, Disconnect, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) - var b bytes.Buffer - err := pk.Encode(&b) + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() - require.NoError(t, err, "Error writing buffer [i:%d]", i) - require.Equal(t, len(wanted.rawBytes), len(b.Bytes()), "Mismatched packet length [i:%d]", i) - require.EqualValues(t, wanted.rawBytes, b.Bytes(), "Mismatched byte values [i:%d]", i) + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) + } +} + +func BenchmarkDisconnectEncode(b *testing.B) { + pk := new(DisconnectPacket) + copier.Copy(pk, expectedPackets[Disconnect][0].packet.(*DisconnectPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) } } diff --git a/packets/fixedheader.go b/packets/fixedheader.go index f96a077..44b2af0 100644 --- a/packets/fixedheader.go +++ b/packets/fixedheader.go @@ -24,8 +24,14 @@ type FixedHeader struct { Remaining int } +// encode encodes the FixedHeader and returns a bytes buffer. +func (fh *FixedHeader) encode(buf *bytes.Buffer) { + buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain)) + encodeLength(buf, fh.Remaining) +} + // Encode encodes the FixedHeader and returns a bytes buffer. -func (fh *FixedHeader) encode() bytes.Buffer { +/*func (fh *FixedHeader) encode() bytes.Buffer { var encoded bytes.Buffer // Encode flags. @@ -37,6 +43,7 @@ func (fh *FixedHeader) encode() bytes.Buffer { return encoded } +*/ // decode extracts the specification bits from the header byte. func (fh *FixedHeader) decode(headerByte byte) error { @@ -77,7 +84,7 @@ func (fh *FixedHeader) decode(headerByte byte) error { } // encodeLength creates length bits for the header. -func encodeLength(length int) []byte { +/*func encodeLength(length int) []byte { encodedLength := make([]byte, 0, 8) for { digit := byte(length % 128) @@ -92,3 +99,19 @@ func encodeLength(length int) []byte { } return encodedLength } +*/ + +// encodeLength writes length bits for the header. +func encodeLength(buf *bytes.Buffer, length int) { + for { + digit := byte(length % 128) + length /= 128 + if length > 0 { + digit |= 0x80 + } + buf.WriteByte(digit) + if length == 0 { + break + } + } +} diff --git a/packets/fixedheader_test.go b/packets/fixedheader_test.go index f6da76a..5307692 100644 --- a/packets/fixedheader_test.go +++ b/packets/fixedheader_test.go @@ -1,7 +1,7 @@ package packets import ( - //"bytes" + "bytes" "math" "testing" @@ -149,15 +149,17 @@ var fixedHeaderExpected = []fixedHeaderTable{ func TestFixedHeaderEncode(t *testing.T) { for i, wanted := range fixedHeaderExpected { - b := wanted.header.encode() - require.Equal(t, len(wanted.rawBytes), len(b.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes) - require.EqualValues(t, wanted.rawBytes, b.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes) + buf := new(bytes.Buffer) + wanted.header.encode(buf) + require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes) + require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes) } } func BenchmarkFixedHeaderEncode(b *testing.B) { + buf := new(bytes.Buffer) for n := 0; n < b.N; n++ { - fixedHeaderExpected[0].header.encode() + fixedHeaderExpected[0].header.encode(buf) } } @@ -203,12 +205,15 @@ func TestEncodeLength(t *testing.T) { } for i, wanted := range tt { - require.Equal(t, wanted.want, encodeLength(wanted.have), "Returned bytes should match length [i:%d] %s", i, wanted.have) + buf := new(bytes.Buffer) + encodeLength(buf, wanted.have) + require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %s", i, wanted.have) } } func BenchmarkEncodeLength(b *testing.B) { + buf := new(bytes.Buffer) for n := 0; n < b.N; n++ { - encodeLength(120) + encodeLength(buf, 120) } } diff --git a/packets/parser.go b/packets/parser.go index 4770ac8..6ab96a0 100644 --- a/packets/parser.go +++ b/packets/parser.go @@ -2,6 +2,7 @@ package packets import ( "bufio" + "bytes" "encoding/binary" "errors" "io" @@ -13,8 +14,8 @@ import ( // Packet is the base interface that all MQTT packets must implement. type Packet interface { - // Encode encodes a packet into a byte array and writes it to a writer. - Encode(io.Writer) error + // Encode encodes a packet into a byte buffer. + Encode(*bytes.Buffer) error // Decode decodes a byte array into a packet struct. Decode([]byte) error diff --git a/packets/pingreq.go b/packets/pingreq.go index 10e3707..b25e8ae 100644 --- a/packets/pingreq.go +++ b/packets/pingreq.go @@ -1,7 +1,7 @@ package packets import ( - "io" + "bytes" ) // PingreqPacket contains the values of an MQTT PINGREQ packet. @@ -10,12 +10,9 @@ type PingreqPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *PingreqPacket) Encode(w io.Writer) error { - - out := pk.FixedHeader.encode() - _, err := out.WriteTo(w) - - return err +func (pk *PingreqPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.encode(buf) + return nil } // Decode extracts the data values from the packet. diff --git a/packets/pingreq_test.go b/packets/pingreq_test.go index 113c142..8e185f6 100644 --- a/packets/pingreq_test.go +++ b/packets/pingreq_test.go @@ -21,12 +21,23 @@ func TestPingreqEncode(t *testing.T) { require.Equal(t, Pingreq, pk.Type, "Mismatched Packet Type [i:%d]", i) require.Equal(t, Pingreq, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) - var b bytes.Buffer - err := pk.Encode(&b) + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() - require.NoError(t, err, "Error writing buffer [i:%d]", i) - require.Equal(t, len(wanted.rawBytes), len(b.Bytes()), "Mismatched packet length [i:%d]", i) - require.EqualValues(t, wanted.rawBytes, b.Bytes(), "Mismatched byte values [i:%d]", i) + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) + } +} + +func BenchmarkPingreqEncode(b *testing.B) { + pk := new(PingreqPacket) + copier.Copy(pk, expectedPackets[Pingreq][0].packet.(*PingreqPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) } } diff --git a/packets/pingresp.go b/packets/pingresp.go index 5c44433..3865054 100644 --- a/packets/pingresp.go +++ b/packets/pingresp.go @@ -1,7 +1,7 @@ package packets import ( - "io" + "bytes" ) // PingrespPacket contains the values of an MQTT PINGRESP packet. @@ -10,12 +10,9 @@ type PingrespPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *PingrespPacket) Encode(w io.Writer) error { - - out := pk.FixedHeader.encode() - _, err := out.WriteTo(w) - - return err +func (pk *PingrespPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.encode(buf) + return nil } // Decode extracts the data values from the packet. diff --git a/packets/pingresp_test.go b/packets/pingresp_test.go index bf3b06b..1ef3759 100644 --- a/packets/pingresp_test.go +++ b/packets/pingresp_test.go @@ -20,12 +20,23 @@ func TestPingrespEncode(t *testing.T) { require.Equal(t, Pingresp, pk.Type, "Mismatched Packet Type [i:%d]", i) require.Equal(t, Pingresp, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) - var b bytes.Buffer - err := pk.Encode(&b) + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() - require.NoError(t, err, "Error writing buffer [i:%d]", i) - require.Equal(t, len(wanted.rawBytes), len(b.Bytes()), "Mismatched packet length [i:%d]", i) - require.EqualValues(t, wanted.rawBytes, b.Bytes(), "Mismatched byte values [i:%d]", i) + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) + } +} + +func BenchmarkPingrespEncode(b *testing.B) { + pk := new(PingrespPacket) + copier.Copy(pk, expectedPackets[Pingresp][0].packet.(*PingrespPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) } } diff --git a/packets/puback.go b/packets/puback.go index 891143a..367b515 100644 --- a/packets/puback.go +++ b/packets/puback.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // PubackPacket contains the values of an MQTT PUBACK packet. @@ -14,20 +13,11 @@ type PubackPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *PubackPacket) Encode(w io.Writer) error { - - var body bytes.Buffer - - // Add the Packet ID. - body.Write(encodeUint16(pk.PacketID)) - pk.Remaining = 2 - - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) - - return err +func (pk *PubackPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil } // Decode extracts the data values from the packet. diff --git a/packets/puback_test.go b/packets/puback_test.go index d400405..b28ca82 100644 --- a/packets/puback_test.go +++ b/packets/puback_test.go @@ -22,21 +22,29 @@ func TestPubackEncode(t *testing.T) { require.Equal(t, Puback, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Puback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) - - encoded := b.Bytes() + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) require.Equal(t, byte(Puback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*PubackPacket).PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) } } +func BenchmarkPubackEncode(b *testing.B) { + pk := new(PubackPacket) + copier.Copy(pk, expectedPackets[Puback][0].packet.(*PubackPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestPubackDecode(t *testing.T) { require.Contains(t, expectedPackets, Puback) for i, wanted := range expectedPackets[Puback] { diff --git a/packets/pubcomp.go b/packets/pubcomp.go index 5b549f8..9692c70 100644 --- a/packets/pubcomp.go +++ b/packets/pubcomp.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // PubcompPacket contains the values of an MQTT PUBCOMP packet. @@ -14,32 +13,20 @@ type PubcompPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *PubcompPacket) Encode(w io.Writer) error { - - var body bytes.Buffer - - // Add the Packet ID. - body.Write(encodeUint16(pk.PacketID)) - pk.Remaining = 2 - - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) - - return err +func (pk *PubcompPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil } // Decode extracts the data values from the packet. func (pk *PubcompPacket) Decode(buf []byte) error { - var err error - pk.PacketID, _, err = decodeUint16(buf, 0) if err != nil { return errors.New(ErrMalformedPacketID) } - return nil } diff --git a/packets/pubcomp_test.go b/packets/pubcomp_test.go index fc406d5..25048aa 100644 --- a/packets/pubcomp_test.go +++ b/packets/pubcomp_test.go @@ -22,21 +22,29 @@ func TestPubcompEncode(t *testing.T) { require.Equal(t, Pubcomp, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Pubcomp, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) - - encoded := b.Bytes() + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) require.Equal(t, byte(Pubcomp<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*PubcompPacket).PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) } } +func BenchmarkPubcompEncode(b *testing.B) { + pk := new(PubcompPacket) + copier.Copy(pk, expectedPackets[Pubcomp][0].packet.(*PubcompPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestPubcompDecode(t *testing.T) { require.Contains(t, expectedPackets, Pubcomp) for i, wanted := range expectedPackets[Pubcomp] { diff --git a/packets/publish.go b/packets/publish.go index 77b5654..257b660 100644 --- a/packets/publish.go +++ b/packets/publish.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // PublishPacket contains the values of an MQTT PUBLISH packet. @@ -16,12 +15,10 @@ type PublishPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *PublishPacket) Encode(w io.Writer) error { - +func (pk *PublishPacket) Encode(buf *bytes.Buffer) error { var body bytes.Buffer - // Write topic name. - body.Write(encodeString(pk.TopicName)) + body.Write(encodeString(pk.TopicName)) // Write topic name. // Add PacketID if QOS is set. // [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. @@ -37,15 +34,11 @@ func (pk *PublishPacket) Encode(w io.Writer) error { // Set remaining length. pk.FixedHeader.Remaining = body.Len() + len(pk.Payload) + pk.FixedHeader.encode(buf) + buf.Write(body.Bytes()) + buf.Write(pk.Payload) - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - out.Write(pk.Payload) - _, err := out.WriteTo(w) - - return err - + return nil } // Decode extracts the data values from the packet. diff --git a/packets/publish_test.go b/packets/publish_test.go index b5ee388..9b3db84 100644 --- a/packets/publish_test.go +++ b/packets/publish_test.go @@ -23,12 +23,13 @@ func TestPublishEncode(t *testing.T) { require.Equal(t, Publish, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Publish, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) + buf := new(bytes.Buffer) + err := pk.Encode(buf) + encoded := buf.Bytes() + if wanted.expect != nil { require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) } else { - encoded := b.Bytes() // If actualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc). if len(wanted.actualBytes) > 0 { @@ -47,11 +48,21 @@ func TestPublishEncode(t *testing.T) { require.Equal(t, wanted.packet.(*PublishPacket).FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched Dup [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*PublishPacket).FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched Retain [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*PublishPacket).PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) } } } +func BenchmarkPublishEncode(b *testing.B) { + pk := new(PublishPacket) + copier.Copy(pk, expectedPackets[Publish][0].packet.(*PublishPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestPublishDecode(t *testing.T) { require.Contains(t, expectedPackets, Publish) for i, wanted := range expectedPackets[Publish] { diff --git a/packets/pubrec.go b/packets/pubrec.go index 5179bfe..2959488 100644 --- a/packets/pubrec.go +++ b/packets/pubrec.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // PubrecPacket contains the values of an MQTT PUBREC packet. @@ -14,20 +13,11 @@ type PubrecPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *PubrecPacket) Encode(w io.Writer) error { - - var body bytes.Buffer - - // Add the Packet ID. - body.Write(encodeUint16(pk.PacketID)) - pk.Remaining = 2 - - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) - - return err +func (pk *PubrecPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil } // Decode extracts the data values from the packet. diff --git a/packets/pubrec_test.go b/packets/pubrec_test.go index cc66312..123a9a9 100644 --- a/packets/pubrec_test.go +++ b/packets/pubrec_test.go @@ -23,21 +23,28 @@ func TestPubrecEncode(t *testing.T) { require.Equal(t, Pubrec, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Pubrec, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) - - encoded := b.Bytes() + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) require.Equal(t, byte(Pubrec<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*PubrecPacket).PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) } +} +func BenchmarkPubrecEncode(b *testing.B) { + pk := new(PubrecPacket) + copier.Copy(pk, expectedPackets[Pubrec][0].packet.(*PubrecPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } } func TestPubrecDecode(t *testing.T) { diff --git a/packets/pubrel.go b/packets/pubrel.go index d358264..e6488f8 100644 --- a/packets/pubrel.go +++ b/packets/pubrel.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // PubrelPacket contains the values of an MQTT PUBREL packet. @@ -14,20 +13,11 @@ type PubrelPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *PubrelPacket) Encode(w io.Writer) error { - - var body bytes.Buffer - - // Add the Packet ID. - body.Write(encodeUint16(pk.PacketID)) - pk.Remaining = 2 - - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) - - return err +func (pk *PubrelPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil } // Decode extracts the data values from the packet. diff --git a/packets/pubrel_test.go b/packets/pubrel_test.go index 850722e..dd374ac 100644 --- a/packets/pubrel_test.go +++ b/packets/pubrel_test.go @@ -23,15 +23,13 @@ func TestPubrelEncode(t *testing.T) { require.Equal(t, Pubrel, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Pubrel, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) - - encoded := b.Bytes() + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) require.Equal(t, byte(Pubrel<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*PubrelPacket).PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) @@ -40,6 +38,16 @@ func TestPubrelEncode(t *testing.T) { } +func BenchmarkPubrelEncode(b *testing.B) { + pk := new(PubrelPacket) + copier.Copy(pk, expectedPackets[Pubrel][0].packet.(*PubrelPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestPubrelDecode(t *testing.T) { require.Contains(t, expectedPackets, Pubrel) for i, wanted := range expectedPackets[Pubrel] { @@ -63,7 +71,7 @@ func TestPubrelDecode(t *testing.T) { } } -func BenchmarkPubrelkDecode(b *testing.B) { +func BenchmarkPubrelDecode(b *testing.B) { pk := newPacket(Pubrel).(*PubrelPacket) pk.FixedHeader.decode(expectedPackets[Pubrel][0].rawBytes[0]) diff --git a/packets/suback.go b/packets/suback.go index ac99ccc..f481358 100644 --- a/packets/suback.go +++ b/packets/suback.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // SubackPacket contains the values of an MQTT SUBACK packet. @@ -15,25 +14,32 @@ type SubackPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *SubackPacket) Encode(w io.Writer) error { +func (pk *SubackPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.encode(buf) - var body bytes.Buffer + bodyLen := buf.Len() - // Encode Packet ID. - body.Write(encodeUint16(pk.PacketID)) + buf.Write(encodeUint16(pk.PacketID)) // Encode Packet ID. + buf.Write(pk.ReturnCodes) // Encode granted QOS flags. - // Encode granted QOS flags. - body.Write(pk.ReturnCodes) + pk.FixedHeader.Remaining = buf.Len() - bodyLen // Set length. - // Set length. - pk.FixedHeader.Remaining = body.Len() + return nil + /* - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) + var body bytes.Buffer - return err + body.Write(encodeUint16(pk.PacketID)) + + body.Write(pk.ReturnCodes) + + // Write header and packet to output. + out := pk.FixedHeader.encode() + out.Write(body.Bytes()) + _, err := out.WriteTo(w) + + return err + */ } // Decode extracts the data values from the packet. diff --git a/packets/suback_test.go b/packets/suback_test.go index 0ed1293..338ff57 100644 --- a/packets/suback_test.go +++ b/packets/suback_test.go @@ -23,10 +23,10 @@ func TestSubackEncode(t *testing.T) { require.Equal(t, Suback, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Suback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) - - encoded := b.Bytes() + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) if wanted.meta != nil { @@ -35,7 +35,6 @@ func TestSubackEncode(t *testing.T) { require.Equal(t, byte(Suback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) } - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*SubackPacket).PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) @@ -43,6 +42,16 @@ func TestSubackEncode(t *testing.T) { } } +func BenchmarkSubackEncode(b *testing.B) { + pk := new(SubackPacket) + copier.Copy(pk, expectedPackets[Suback][0].packet.(*SubackPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestSubackDecode(t *testing.T) { require.Contains(t, expectedPackets, Suback) for i, wanted := range expectedPackets[Suback] { diff --git a/packets/subscribe.go b/packets/subscribe.go index bf68ba9..5758a06 100644 --- a/packets/subscribe.go +++ b/packets/subscribe.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // SubscribePacket contains the values of an MQTT SUBSCRIBE packet. @@ -16,7 +15,7 @@ type SubscribePacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *SubscribePacket) Encode(w io.Writer) error { +func (pk *SubscribePacket) Encode(buf *bytes.Buffer) error { var body bytes.Buffer @@ -36,13 +35,38 @@ func (pk *SubscribePacket) Encode(w io.Writer) error { // Set length. pk.FixedHeader.Remaining = body.Len() + pk.FixedHeader.encode(buf) + buf.Write(body.Bytes()) - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) + /* + var body bytes.Buffer - return err + // Add the Packet ID. + // [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + if pk.PacketID == 0 { + return errors.New(ErrMissingPacketID) + } + + body.Write(encodeUint16(pk.PacketID)) + + // Add all provided topic names and associated QOS flags. + for i, topic := range pk.Topics { + body.Write(encodeString(topic)) + body.WriteByte(pk.Qoss[i]) + } + + // Set length. + pk.FixedHeader.Remaining = body.Len() + + // Write header and packet to output. + out := pk.FixedHeader.encode() + out.Write(body.Bytes()) + _, err := out.WriteTo(w) + + return err + */ + + return nil } diff --git a/packets/subscribe_test.go b/packets/subscribe_test.go index 58e26a0..76ee304 100644 --- a/packets/subscribe_test.go +++ b/packets/subscribe_test.go @@ -22,18 +22,14 @@ func TestSubscribeEncode(t *testing.T) { require.Equal(t, Subscribe, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Subscribe, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) + buf := new(bytes.Buffer) + err := pk.Encode(buf) + encoded := buf.Bytes() + if wanted.expect != nil { - require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) - } else { - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) - - encoded := b.Bytes() - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) if wanted.meta != nil { require.Equal(t, byte(Subscribe<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) @@ -42,11 +38,21 @@ func TestSubscribeEncode(t *testing.T) { } require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.(*SubscribePacket).PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*SubscribePacket).Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*SubscribePacket).Qoss, pk.Qoss, "Mismatched Qoss slice [i:%d] %s", i, wanted.desc) } + + } +} + +func BenchmarkSubscribeEncode(b *testing.B) { + pk := new(SubscribePacket) + copier.Copy(pk, expectedPackets[Subscribe][0].packet.(*SubscribePacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) } } diff --git a/packets/unsuback.go b/packets/unsuback.go index 8e81e2f..9a37da4 100644 --- a/packets/unsuback.go +++ b/packets/unsuback.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // UnsubackPacket contains the values of an MQTT UNSUBACK packet. @@ -14,21 +13,11 @@ type UnsubackPacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *UnsubackPacket) Encode(w io.Writer) error { - - var body bytes.Buffer - - // Add the Packet ID. - body.Write(encodeUint16(pk.PacketID)) - pk.Remaining = 2 - - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) - - return err - +func (pk *UnsubackPacket) Encode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil } // Decode extracts the data values from the packet. diff --git a/packets/unsuback_test.go b/packets/unsuback_test.go index 9aa0563..e3d99f6 100644 --- a/packets/unsuback_test.go +++ b/packets/unsuback_test.go @@ -22,10 +22,10 @@ func TestUnsubackEncode(t *testing.T) { require.Equal(t, Unsuback, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Unsuback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) - - encoded := b.Bytes() + buf := new(bytes.Buffer) + err := pk.Encode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) if wanted.meta != nil { @@ -34,13 +34,22 @@ func TestUnsubackEncode(t *testing.T) { require.Equal(t, byte(Unsuback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) } - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) require.Equal(t, wanted.packet.(*UnsubackPacket).PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) } } +func BenchmarkUnsubackEncode(b *testing.B) { + pk := new(UnsubackPacket) + copier.Copy(pk, expectedPackets[Unsuback][0].packet.(*UnsubackPacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestUnsubackDecode(t *testing.T) { require.Contains(t, expectedPackets, Unsuback) for i, wanted := range expectedPackets[Unsuback] { diff --git a/packets/unsubscribe.go b/packets/unsubscribe.go index 1ec06a3..b9642fa 100644 --- a/packets/unsubscribe.go +++ b/packets/unsubscribe.go @@ -3,7 +3,6 @@ package packets import ( "bytes" "errors" - "io" ) // UnsubscribePacket contains the values of an MQTT UNSUBSCRIBE packet. @@ -15,8 +14,7 @@ type UnsubscribePacket struct { } // Encode encodes and writes the packet data values to the buffer. -func (pk *UnsubscribePacket) Encode(w io.Writer) error { - +func (pk *UnsubscribePacket) Encode(buf *bytes.Buffer) error { var body bytes.Buffer // Add the Packet ID. @@ -34,14 +32,37 @@ func (pk *UnsubscribePacket) Encode(w io.Writer) error { // Set length. pk.FixedHeader.Remaining = body.Len() + pk.FixedHeader.encode(buf) + buf.Write(body.Bytes()) - // Write header and packet to output. - out := pk.FixedHeader.encode() - out.Write(body.Bytes()) - _, err := out.WriteTo(w) + return nil - return err + /* + var body bytes.Buffer + // Add the Packet ID. + // [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + if pk.PacketID == 0 { + return errors.New(ErrMissingPacketID) + } + + body.Write(encodeUint16(pk.PacketID)) + + // Add all provided topic names and associated QOS flags. + for _, topic := range pk.Topics { + body.Write(encodeString(topic)) + } + + // Set length. + pk.FixedHeader.Remaining = body.Len() + + // Write header and packet to output. + out := pk.FixedHeader.encode() + out.Write(body.Bytes()) + _, err := out.WriteTo(w) + + return err + */ } // Decode extracts the data values from the packet. diff --git a/packets/unsubscribe_test.go b/packets/unsubscribe_test.go index f46aa72..4cfb715 100644 --- a/packets/unsubscribe_test.go +++ b/packets/unsubscribe_test.go @@ -22,18 +22,14 @@ func TestUnsubscribeEncode(t *testing.T) { require.Equal(t, Unsubscribe, pk.Type, "Mismatched Packet Type [i:%d] %s", i, wanted.desc) require.Equal(t, Unsubscribe, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - var b bytes.Buffer - err := pk.Encode(&b) + buf := new(bytes.Buffer) + err := pk.Encode(buf) + encoded := buf.Bytes() + if wanted.expect != nil { - require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) - } else { - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) - - encoded := b.Bytes() - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) if wanted.meta != nil { require.Equal(t, byte(Unsubscribe<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) @@ -50,6 +46,16 @@ func TestUnsubscribeEncode(t *testing.T) { } } +func BenchmarkUnsubscribeEncode(b *testing.B) { + pk := new(UnsubscribePacket) + copier.Copy(pk, expectedPackets[Unsubscribe][0].packet.(*UnsubscribePacket)) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.Encode(buf) + } +} + func TestUnsubscribeDecode(t *testing.T) { require.Contains(t, expectedPackets, Unsubscribe) for i, wanted := range expectedPackets[Unsubscribe] {