diff --git a/server/internal/packets/codec.go b/server/internal/packets/codec.go index 7ab4cdc..1efea61 100644 --- a/server/internal/packets/codec.go +++ b/server/internal/packets/codec.go @@ -28,6 +28,10 @@ func decodeString(buf []byte, offset int) (string, int, error) { return "", 0, err } + if !validUTF8(b) { + return "", 0, ErrOffsetStrInvalidUTF8 + } + return bytesToString(b), n, nil } @@ -39,12 +43,10 @@ func decodeBytes(buf []byte, offset int) ([]byte, int, error) { } if next+int(length) > len(buf) { - return make([]byte, 0, 0), 0, ErrOffsetStrOutOfRange + return make([]byte, 0, 0), 0, ErrOffsetBytesOutOfRange } - if !validUTF8(buf[next : next+int(length)]) { - return make([]byte, 0, 0), 0, ErrOffsetStrInvalidUTF8 - } + // Note: there is no validUTF8() test for []byte payloads return buf[next : next+int(length)], next + int(length), nil } diff --git a/server/internal/packets/codec_test.go b/server/internal/packets/codec_test.go index 2cd0438..7778bc0 100644 --- a/server/internal/packets/codec_test.go +++ b/server/internal/packets/codec_test.go @@ -1,6 +1,8 @@ package packets import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/require" @@ -19,15 +21,16 @@ func BenchmarkBytesToString(b *testing.B) { func TestDecodeString(t *testing.T) { expect := []struct { + name string rawBytes []byte - result []string + result string offset int - shouldFail bool + shouldFail error }{ { offset: 0, rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, - result: []string{"a/b/c/d", "a"}, + result: "a/b/c/d", }, { offset: 14, @@ -41,36 +44,32 @@ func TestDecodeString(t *testing.T) { 0, 3, // Client ID - MSB+LSB 'h', 'e', 'y', // Client ID "zen"}, }, - result: []string{"hey"}, + result: "hey", }, - { offset: 2, rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97}, - result: []string{"1/2/3/4/a/b/c/d/e/^/@/!", "a"}, + result: "1/2/3/4/a/b/c/d/e/^/@/!", }, { offset: 0, rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38}, - result: []string{"x/y/z", "!@#$%^&"}, + result: "x/y/z", }, { offset: 0, rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'}, - result: []string{"a/b/c/d", "z"}, - shouldFail: true, + shouldFail: ErrOffsetBytesOutOfRange, }, { offset: 5, rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'}, - result: []string{"a/b/c/d", "x"}, - shouldFail: true, + shouldFail: ErrOffsetBytesOutOfRange, }, { offset: 9, rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'}, - result: []string{"a/b/c/d", "y"}, - shouldFail: true, + shouldFail: ErrOffsetUintOutOfRange, }, { offset: 17, @@ -86,20 +85,26 @@ func TestDecodeString(t *testing.T) { 0, 6, // Will Topic - MSB+LSB 'l', }, - result: []string{"lwt"}, - shouldFail: true, + shouldFail: ErrOffsetBytesOutOfRange, + }, + { + offset: 0, + rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100}, + shouldFail: ErrOffsetStrInvalidUTF8, }, } for i, wanted := range expect { - result, _, err := decodeString(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding string [i:%d]", i) - continue - } - - require.NoError(t, err, "Error decoding string [i:%d]", i) - require.Equal(t, wanted.result[0], result, "Incorrect decoded value [i:%d]", i) + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, _, err := decodeString(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + + require.NoError(t, err) + require.Equal(t, wanted.result, result) + }) } } @@ -116,7 +121,7 @@ func TestDecodeBytes(t *testing.T) { result []uint8 next int offset int - shouldFail bool + shouldFail error }{ { rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session) @@ -132,33 +137,27 @@ func TestDecodeBytes(t *testing.T) { }, { rawBytes: []byte{0, 4, 77, 81}, - result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), offset: 0, - shouldFail: true, + shouldFail: ErrOffsetBytesOutOfRange, }, { rawBytes: []byte{0, 4, 77, 81}, - result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), offset: 8, - shouldFail: true, - }, - { - rawBytes: []byte{0, 4, 77, 81}, - result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), - offset: 0, - shouldFail: true, + shouldFail: ErrOffsetUintOutOfRange, }, } for i, wanted := range expect { - result, _, err := decodeBytes(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding bytes [i:%d]", i) - continue - } + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, _, err := decodeBytes(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } - require.NoError(t, err, "Error decoding bytes [i:%d]", i) - require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) + require.NoError(t, err) + require.Equal(t, wanted.result, result) + }) } } @@ -174,7 +173,7 @@ func TestDecodeByte(t *testing.T) { rawBytes []byte result uint8 offset int - shouldFail bool + shouldFail error }{ { rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes @@ -198,22 +197,23 @@ func TestDecodeByte(t *testing.T) { }, { rawBytes: []byte{0, 4, 77, 80, 82, 84}, - result: uint8(0x00), offset: 8, - shouldFail: true, + shouldFail: ErrOffsetByteOutOfRange, }, } for i, wanted := range expect { - result, offset, err := decodeByte(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding byte [i:%d]", i) - continue - } - - require.NoError(t, err, "Error decoding byte [i:%d]", i) - require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) - require.Equal(t, i+1, offset, "Incorrect offset value [i:%d]", i) + t.Run(fmt.Sprint(i), func(t *testing.T) { + result, offset, err := decodeByte(wanted.rawBytes, wanted.offset) + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + + require.NoError(t, err) + require.Equal(t, wanted.result, result) + require.Equal(t, i+1, offset) + }) } } @@ -229,7 +229,7 @@ func TestDecodeUint16(t *testing.T) { rawBytes []byte result uint16 offset int - shouldFail bool + shouldFail error }{ { rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, @@ -243,22 +243,24 @@ func TestDecodeUint16(t *testing.T) { }, { rawBytes: []byte{0, 7, 255, 47}, - result: uint16(0x761), offset: 8, - shouldFail: true, + shouldFail: ErrOffsetUintOutOfRange, }, } for i, wanted := range expect { + t.Run(fmt.Sprint(i), func(t *testing.T) { result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding uint16 [i:%d]", i) - continue - } + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } + - require.NoError(t, err, "Error decoding uint16 [i:%d]", i) - require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) - require.Equal(t, i+2, offset, "Incorrect offset value [i:%d]", i) + require.NoError(t, err) + require.Equal(t, wanted.result, result) + require.Equal(t, i+2, offset) + }) } } @@ -274,7 +276,7 @@ func TestDecodeByteBool(t *testing.T) { rawBytes []byte result bool offset int - shouldFail bool + shouldFail error }{ { rawBytes: []byte{0x00, 0x00}, @@ -287,20 +289,22 @@ func TestDecodeByteBool(t *testing.T) { { rawBytes: []byte{0x01, 0x00}, offset: 5, - shouldFail: true, + shouldFail: ErrOffsetBoolOutOfRange, }, } for i, wanted := range expect { + t.Run(fmt.Sprint(i), func(t *testing.T) { result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding byte bool [i:%d]", i) - continue - } + if wanted.shouldFail != nil { + require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail) + return + } - require.NoError(t, err, "Error decoding byte bool [i:%d]", i) - require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) - require.Equal(t, 1, offset, "Incorrect offset value [i:%d]", i) + require.NoError(t, err) + require.Equal(t, wanted.result, result) + require.Equal(t, 1, offset) + }) } } diff --git a/server/internal/packets/packets.go b/server/internal/packets/packets.go index 9a9376b..158a71d 100644 --- a/server/internal/packets/packets.go +++ b/server/internal/packets/packets.go @@ -60,7 +60,6 @@ var ( // PACKETS ErrProtocolViolation = errors.New("protocol violation") - ErrOffsetStrOutOfRange = errors.New("offset string out of range") ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range") ErrOffsetByteOutOfRange = errors.New("offset byte out of range") ErrOffsetBoolOutOfRange = errors.New("offset bool out of range")