Files
mochi-mqtt/packets/packets_test.go

506 lines
16 KiB
Go

// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"testing"
"github.com/jinzhu/copier"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var packetList = []byte{
Connect,
Connack,
Publish,
Puback,
Pubrec,
Pubrel,
Pubcomp,
Subscribe,
Suback,
Unsubscribe,
Unsuback,
Pingreq,
Pingresp,
Disconnect,
Auth,
}
var pkTable = []TPacketCase{
TPacketData[Connect].Get(TConnectMqtt311),
TPacketData[Connect].Get(TConnectMqtt5),
TPacketData[Connect].Get(TConnectUserPassLWT),
TPacketData[Connack].Get(TConnackAcceptedMqtt5),
TPacketData[Connack].Get(TConnackAcceptedNoSession),
TPacketData[Publish].Get(TPublishBasic),
TPacketData[Publish].Get(TPublishMqtt5),
TPacketData[Puback].Get(TPuback),
TPacketData[Pubrec].Get(TPubrec),
TPacketData[Pubrel].Get(TPubrel),
TPacketData[Pubcomp].Get(TPubcomp),
TPacketData[Subscribe].Get(TSubscribe),
TPacketData[Subscribe].Get(TSubscribeMqtt5),
TPacketData[Suback].Get(TSuback),
TPacketData[Unsubscribe].Get(TUnsubscribe),
TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
TPacketData[Pingreq].Get(TPingreq),
TPacketData[Pingresp].Get(TPingresp),
TPacketData[Disconnect].Get(TDisconnect),
TPacketData[Disconnect].Get(TDisconnectMqtt5),
}
func TestNewPackets(t *testing.T) {
s := NewPackets()
require.NotNil(t, s.internal)
}
func TestPacketsAdd(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{})
require.Contains(t, s.internal, "cl1")
}
func TestPacketsGet(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
pk, ok := s.Get("cl1")
require.True(t, ok)
require.Equal(t, "a1", pk.TopicName)
}
func TestPacketsGetAll(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
s.Add("cl3", Packet{TopicName: "a3"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Contains(t, s.internal, "cl3")
subs := s.GetAll()
require.Len(t, subs, 3)
}
func TestPacketsLen(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Equal(t, 2, s.Len())
}
func TestSPacketsDelete(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
require.Contains(t, s.internal, "cl1")
s.Delete("cl1")
_, ok := s.Get("cl1")
require.False(t, ok)
}
func TestFormatPacketID(t *testing.T) {
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
packet := &Packet{PacketID: id}
require.Equal(t, fmt.Sprint(id), packet.FormatID())
}
}
func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
p := &Subscription{
Qos: 2,
NoLocal: true,
RetainAsPublished: true,
RetainHandling: 2,
}
x := new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
p = &Subscription{
Qos: 1,
NoLocal: false,
RetainAsPublished: false,
RetainHandling: 1,
}
x = new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
}
func TestPacketEncode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !encodeTestOK(wanted) {
return
}
pk := new(Packet)
copier.Copy(pk, wanted.Packet)
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
pk.Mods.AllowResponseInfo = true
buf := new(bytes.Buffer)
var err error
switch pkt {
case Connect:
err = pk.ConnectEncode(buf)
case Connack:
err = pk.ConnackEncode(buf)
case Publish:
err = pk.PublishEncode(buf)
case Puback:
err = pk.PubackEncode(buf)
case Pubrec:
err = pk.PubrecEncode(buf)
case Pubrel:
err = pk.PubrelEncode(buf)
case Pubcomp:
err = pk.PubcompEncode(buf)
case Subscribe:
err = pk.SubscribeEncode(buf)
case Suback:
err = pk.SubackEncode(buf)
case Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case Unsuback:
err = pk.UnsubackEncode(buf)
case Pingreq:
err = pk.PingreqEncode(buf)
case Pingresp:
err = pk.PingrespEncode(buf)
case Disconnect:
err = pk.DisconnectEncode(buf)
case Auth:
err = pk.AuthEncode(buf)
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
encoded := buf.Bytes()
// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
if len(wanted.ActualBytes) > 0 {
wanted.RawBytes = wanted.ActualBytes
}
require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
})
}
}
}
func TestPacketDecode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !decodeTestOK(wanted) {
return
}
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
pk.Mods.AllowResponseInfo = true
pk.FixedHeader.Decode(wanted.RawBytes[0])
if len(wanted.RawBytes) > 0 {
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
}
if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
pk.ProtocolVersion = wanted.Packet.ProtocolVersion
}
buf := wanted.RawBytes[2:]
var err error
switch pkt {
case Connect:
err = pk.ConnectDecode(buf)
case Connack:
err = pk.ConnackDecode(buf)
case Publish:
err = pk.PublishDecode(buf)
case Puback:
err = pk.PubackDecode(buf)
case Pubrec:
err = pk.PubrecDecode(buf)
case Pubrel:
err = pk.PubrelDecode(buf)
case Pubcomp:
err = pk.PubcompDecode(buf)
case Subscribe:
err = pk.SubscribeDecode(buf)
case Suback:
err = pk.SubackDecode(buf)
case Unsubscribe:
err = pk.UnsubscribeDecode(buf)
case Unsuback:
err = pk.UnsubackDecode(buf)
case Pingreq:
err = pk.PingreqDecode(buf)
case Pingresp:
err = pk.PingrespDecode(buf)
case Disconnect:
err = pk.DisconnectDecode(buf)
case Auth:
err = pk.AuthDecode(buf)
}
if wanted.FailFirst != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
if pkt == Connect {
// we use ProtocolVersion for controlling packet encoding, but we don't need to test
// against it unless it's a connect packet.
require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
}
require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
})
}
}
}
func TestValidate(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if wanted.Group == "validate" || wanted.Primary {
pk := wanted.Packet
var err error
switch pkt {
case Connect:
err = pk.ConnectValidate()
case Publish:
err = pk.PublishValidate(1024)
case Subscribe:
err = pk.SubscribeValidate()
case Unsubscribe:
err = pk.UnsubscribeValidate()
case Auth:
err = pk.AuthValidate()
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
}
}
})
}
}
}
func TestAckValidatePubrec(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoMatchingSubscribers.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicNameInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrPayloadFormatInvalid.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubrel(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubcomp(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateSuback(t *testing.T) {
for _, b := range []byte{
CodeGrantedQos0.Code,
CodeGrantedQos1.Code,
CodeGrantedQos2.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrSharedSubscriptionsNotSupported.Code,
ErrSubscriptionIdentifiersNotSupported.Code,
ErrWildcardSubscriptionsNotSupported.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateUnsuback(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoSubscriptionExisted.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestReasonCodeValidMisc(t *testing.T) {
pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
require.True(t, pk.ReasonCodeValid())
}
func TestCopy(t *testing.T) {
for _, tt := range pkTable {
pkc := tt.Packet.Copy(true)
require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
pkcc := tt.Packet.Copy(false)
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
}
}
func TestMergeSubscription(t *testing.T) {
sub := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 0,
RetainAsPublished: false,
NoLocal: false,
Identifier: 1,
}
sub2 := Subscription{
Filter: "a/b/d",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 2,
}
expect := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 1,
Identifiers: map[string]int{
"a/b/c": 1,
"a/b/d": 2,
},
}
require.Equal(t, expect, sub.Merge(sub2))
}