mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-09-27 04:26:23 +08:00
Correctly validate WillProperties (#210)
Co-authored-by: sukvojte <sukvojte@gmail.com>
This commit is contained in:
@@ -32,6 +32,7 @@ const (
|
||||
Pingresp // 13
|
||||
Disconnect // 14
|
||||
Auth // 15
|
||||
WillProperties byte = 99 // Special byte for validating Will Properties.
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -313,7 +314,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
(&pk.Properties).Encode(pk, pb, 0)
|
||||
(&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -322,7 +323,7 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error {
|
||||
if pk.Connect.WillFlag {
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
(&pk.Connect).WillProperties.Encode(pk, pb, 0)
|
||||
(&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0)
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -393,7 +394,7 @@ func (pk *Packet) ConnectDecode(buf []byte) error {
|
||||
|
||||
if pk.Connect.WillFlag { // [MQTT-3.1.2-7]
|
||||
if pk.ProtocolVersion == 5 {
|
||||
n, err := pk.Connect.WillProperties.Decode(pk.FixedHeader.Type, bytes.NewBuffer(buf[offset:]))
|
||||
n, err := pk.Connect.WillProperties.Decode(WillProperties, bytes.NewBuffer(buf[offset:]))
|
||||
if err != nil {
|
||||
return ErrMalformedWillProperties
|
||||
}
|
||||
@@ -496,7 +497,7 @@ func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+2) // +SessionPresent +ReasonCode
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -539,7 +540,7 @@ func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -608,7 +609,7 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.Payload))
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -692,7 +693,7 @@ func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
}
|
||||
@@ -833,7 +834,7 @@ func (pk *Packet) SubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+len(pk.ReasonCodes))
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes))
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -890,7 +891,7 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -985,7 +986,7 @@ func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -1038,7 +1039,7 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error {
|
||||
|
||||
if pk.ProtocolVersion == 5 {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len()+xb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
}
|
||||
|
||||
@@ -1101,7 +1102,7 @@ func (pk *Packet) AuthEncode(buf *bytes.Buffer) error {
|
||||
nb.WriteByte(pk.ReasonCode)
|
||||
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
pk.Properties.Encode(pk, pb, nb.Len())
|
||||
pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len())
|
||||
nb.Write(pb.Bytes())
|
||||
|
||||
pk.FixedHeader.Remaining = nb.Len()
|
||||
|
@@ -42,11 +42,11 @@ const (
|
||||
|
||||
// validPacketProperties indicates which properties are valid for which packet types.
|
||||
var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropPayloadFormat: {Publish: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1},
|
||||
PropContentType: {Publish: 1},
|
||||
PropResponseTopic: {Publish: 1},
|
||||
PropCorrelationData: {Publish: 1},
|
||||
PropPayloadFormat: {Publish: 1, WillProperties: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1, WillProperties: 1},
|
||||
PropContentType: {Publish: 1, WillProperties: 1},
|
||||
PropResponseTopic: {Publish: 1, WillProperties: 1},
|
||||
PropCorrelationData: {Publish: 1, WillProperties: 1},
|
||||
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
|
||||
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
|
||||
PropAssignedClientID: {Connack: 1},
|
||||
@@ -54,7 +54,7 @@ var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropRequestProblemInfo: {Connect: 1},
|
||||
PropWillDelayInterval: {Connect: 1},
|
||||
PropWillDelayInterval: {WillProperties: 1},
|
||||
PropRequestResponseInfo: {Connect: 1},
|
||||
PropResponseInfo: {Connack: 1},
|
||||
PropServerReference: {Connack: 1, Disconnect: 1},
|
||||
@@ -64,7 +64,7 @@ var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropTopicAlias: {Publish: 1},
|
||||
PropMaximumQos: {Connack: 1},
|
||||
PropRetainAvailable: {Connack: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1, WillProperties: 1},
|
||||
PropMaximumPacketSize: {Connect: 1, Connack: 1},
|
||||
PropWildcardSubAvailable: {Connack: 1},
|
||||
PropSubIDAvailable: {Connack: 1},
|
||||
@@ -194,14 +194,12 @@ func (p *Properties) canEncode(pkt byte, k byte) bool {
|
||||
}
|
||||
|
||||
// Encode encodes properties into a bytes buffer.
|
||||
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
func (p *Properties) Encode(pkt byte, mods Mods, b *bytes.Buffer, n int) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
pkt := pk.FixedHeader.Type
|
||||
|
||||
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
|
||||
buf.WriteByte(PropPayloadFormat)
|
||||
buf.WriteByte(p.PayloadFormat)
|
||||
@@ -217,13 +215,13 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseTopic)
|
||||
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropCorrelationData)
|
||||
buf.Write(encodeBytes(p.CorrelationData))
|
||||
}
|
||||
@@ -277,7 +275,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.WriteByte(p.RequestResponseInfo)
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
if mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseInfo)
|
||||
buf.Write(encodeString(p.ResponseInfo))
|
||||
}
|
||||
@@ -289,9 +287,9 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
|
||||
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
|
||||
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
b := encodeString(p.ReasonString)
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
|
||||
if mods.MaxSize == 0 || uint32(n+len(b)+1) < mods.MaxSize {
|
||||
buf.WriteByte(PropReasonString)
|
||||
buf.Write(b)
|
||||
}
|
||||
@@ -322,7 +320,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
buf.WriteByte(p.RetainAvailable)
|
||||
}
|
||||
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
if !mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
for _, v := range p.User {
|
||||
pb.WriteByte(PropUser)
|
||||
@@ -331,7 +329,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
}
|
||||
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
|
||||
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
|
||||
if mods.MaxSize == 0 || uint32(n+pb.Len()+1) < mods.MaxSize {
|
||||
buf.Write(pb.Bytes())
|
||||
}
|
||||
}
|
||||
@@ -361,7 +359,7 @@ func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
}
|
||||
|
||||
// Decode decodes property bytes into a properties struct.
|
||||
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
func (p *Properties) Decode(pkt byte, b *bytes.Buffer) (n int, err error) {
|
||||
if p == nil {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -384,8 +382,8 @@ func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if _, ok := validPacketProperties[k][pk]; !ok {
|
||||
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
|
||||
if _, ok := validPacketProperties[k][pkt]; !ok {
|
||||
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pkt, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
switch k {
|
||||
|
@@ -202,14 +202,14 @@ func init() {
|
||||
func TestEncodeProperties(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
|
||||
require.Equal(t, propertiesBytes, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
|
||||
@@ -219,7 +219,7 @@ func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
|
||||
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
|
||||
@@ -232,7 +232,7 @@ func TestEncodePropertiesNil(t *testing.T) {
|
||||
|
||||
pr := tmp{}
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
|
||||
pr.p.Encode(Reserved, Mods{}, b, 0)
|
||||
require.Equal(t, []byte{}, b.Bytes())
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestEncodeZeroProperties(t *testing.T) {
|
||||
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
|
||||
props := new(Properties)
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0)
|
||||
require.Equal(t, []byte{0x00}, b.Bytes())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user