Correctly validate WillProperties (#210)

Co-authored-by: sukvojte <sukvojte@gmail.com>
This commit is contained in:
JB
2023-05-04 22:37:23 +01:00
committed by GitHub
parent 4b49652a8c
commit 1ec880844d
3 changed files with 53 additions and 54 deletions

View File

@@ -16,22 +16,23 @@ import (
// All of the valid packet types and their packet identifier. // All of the valid packet types and their packet identifier.
const ( const (
Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets. Reserved byte = iota // 0 - we use this in packet tests to indicate special-test or all packets.
Connect // 1 Connect // 1
Connack // 2 Connack // 2
Publish // 3 Publish // 3
Puback // 4 Puback // 4
Pubrec // 5 Pubrec // 5
Pubrel // 6 Pubrel // 6
Pubcomp // 7 Pubcomp // 7
Subscribe // 8 Subscribe // 8
Suback // 9 Suback // 9
Unsubscribe // 10 Unsubscribe // 10
Unsuback // 11 Unsuback // 11
Pingreq // 12 Pingreq // 12
Pingresp // 13 Pingresp // 13
Disconnect // 14 Disconnect // 14
Auth // 15 Auth // 15
WillProperties byte = 99 // Special byte for validating Will Properties.
) )
var ( var (
@@ -313,7 +314,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
(&pk.Properties).Encode(pk, pb, 0) (&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -322,7 +323,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
if pk.Connect.WillFlag { if pk.Connect.WillFlag {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
(&pk.Connect).WillProperties.Encode(pk, pb, 0) (&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -393,7 +394,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
if pk.Connect.WillFlag { // [MQTT-3.1.2-7] if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:])) n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
if err != nil { if err != nil {
return ErrMalformedWillProperties return ErrMalformedWillProperties
} }
@@ -496,7 +497,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -539,7 +540,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode) nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -608,7 +609,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload)) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -692,7 +693,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 { if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
nb.WriteByte(pk.ReasonCode) nb.WriteByte(pk.ReasonCode)
} }
@@ -833,7 +834,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes)) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -890,7 +891,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len()) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -985,7 +986,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -1038,7 +1039,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
if pk.ProtocolVersion == 5 { if pk.ProtocolVersion == 5 {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len()) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
} }
@@ -1101,7 +1102,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
nb.WriteByte(pk.ReasonCode) nb.WriteByte(pk.ReasonCode)
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
pk.Properties.Encode(pk, pb, nb.Len()) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
nb.Write(pb.Bytes()) nb.Write(pb.Bytes())
pk.FixedHeader.Remaining = nb.Len() pk.FixedHeader.Remaining = nb.Len()

View File

@@ -42,11 +42,11 @@ const (
// validPacketProperties indicates which properties are valid for which packet types. // validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{ var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1}, PropPayloadFormat: {Publish: 1, WillProperties: 1},
PropMessageExpiryInterval: {Publish: 1}, PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
PropContentType: {Publish: 1}, PropContentType: {Publish: 1, WillProperties: 1},
PropResponseTopic: {Publish: 1}, PropResponseTopic: {Publish: 1, WillProperties: 1},
PropCorrelationData: {Publish: 1}, PropCorrelationData: {Publish: 1, WillProperties: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1}, PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1}, PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1}, PropAssignedClientID: {Connack: 1},
@@ -54,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1}, PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1}, PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1}, PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {Connect: 1}, PropWillDelayInterval: {WillProperties: 1},
PropRequestResponseInfo: {Connect: 1}, PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1}, PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1}, PropServerReference: {Connack: 1, Disconnect: 1},
@@ -64,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{
PropTopicAlias: {Publish: 1}, PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1}, PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1}, PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1}, PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1}, PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1}, PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1}, PropSubIDAvailable: {Connack: 1},
@@ -194,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool {
} }
// Encode encodes properties into a bytes buffer. // Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) { func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
if p == nil { if p == nil {
return return
} }
var buf bytes.Buffer var buf bytes.Buffer
pkt := pk.FixedHeader.Type
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag { if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat) buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat) buf.WriteByte(p.PayloadFormat)
@@ -217,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19] buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
} }
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14] if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28] p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic) buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13] buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
} }
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28] if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData) buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData)) buf.Write(encodeBytes(p.CorrelationData))
} }
@@ -277,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RequestResponseInfo) buf.WriteByte(p.RequestResponseInfo)
} }
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28] if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo) buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo)) buf.Write(encodeString(p.ResponseInfo))
} }
@@ -289,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2] // [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2] // [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" { if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString) b := encodeString(p.ReasonString)
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize { if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
buf.WriteByte(PropReasonString) buf.WriteByte(PropReasonString)
buf.Write(b) buf.Write(b)
} }
@@ -322,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
buf.WriteByte(p.RetainAvailable) buf.WriteByte(p.RetainAvailable)
} }
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) { if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := bytes.NewBuffer([]byte{}) pb := bytes.NewBuffer([]byte{})
for _, v := range p.User { for _, v := range p.User {
pb.WriteByte(PropUser) pb.WriteByte(PropUser)
@@ -331,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
} }
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3] // [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3] // [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize { if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
buf.Write(pb.Bytes()) buf.Write(pb.Bytes())
} }
} }
@@ -361,7 +359,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
} }
// Decode decodes property bytes into a properties struct. // Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) { func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
if p == nil { if p == nil {
return 0, nil return 0, nil
} }
@@ -384,8 +382,8 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
return n + bu, err return n + bu, err
} }
if _, ok := validPacketProperties[k][pk]; !ok { if _, ok := validPacketProperties[k][pkt]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty) return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
} }
switch k { switch k {

View File

@@ -202,14 +202,14 @@ func init() {
func TestEncodeProperties(t *testing.T) { func TestEncodeProperties(t *testing.T) {
props := propertiesStruct props := propertiesStruct
b := bytes.NewBuffer([]byte{}) b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0) props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes()) require.Equal(t, propertiesBytes, b.Bytes())
} }
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) { func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct props := propertiesStruct
b := bytes.NewBuffer([]byte{}) b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0) props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes()) require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6})) require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5})) require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
@@ -219,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) { func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct props := propertiesStruct
b := bytes.NewBuffer([]byte{}) b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0) props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes()) require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5}) require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4}) require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
@@ -232,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) {
pr := tmp{} pr := tmp{}
b := bytes.NewBuffer([]byte{}) b := bytes.NewBuffer([]byte{})
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0) pr.p.Encode(Reserved, Mods{}, b, 0)
require.Equal(t, []byte{}, b.Bytes()) require.Equal(t, []byte{}, b.Bytes())
} }
@@ -240,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero. // [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties) props := new(Properties)
b := bytes.NewBuffer([]byte{}) b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0) props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes()) require.Equal(t, []byte{0x00}, b.Bytes())
} }
@@ -250,7 +250,7 @@ func TestDecodeProperties(t *testing.T) {
props := new(Properties) props := new(Properties)
n, err := props.Decode(Reserved, b) n, err := props.Decode(Reserved, b)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 172 + 2, n) require.Equal(t, 172+2, n)
require.EqualValues(t, propertiesStruct, *props) require.EqualValues(t, propertiesStruct, *props)
} }