diff --git a/internal/util/util.go b/internal/util/util.go index 0c38ba3c..aedeb3c5 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -15,3 +15,11 @@ func RandSeq(n int) string { } return string(b) } + +// GetPadding Returns the padding required to make the length a multiple of 4 +func GetPadding(len int) int { + if len%4 == 0 { + return 0 + } + return 4 - (len % 4) +} diff --git a/internal/util/util_test.go b/internal/util/util_test.go new file mode 100644 index 00000000..175ccebb --- /dev/null +++ b/internal/util/util_test.go @@ -0,0 +1,40 @@ +package util + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRandSeq(t *testing.T) { + if len(RandSeq(10)) != 10 { + t.Errorf("RandSeq return invalid length") + } + + var isLetter = regexp.MustCompile(`^[a-zA-Z]+$`).MatchString + if !isLetter(RandSeq(10)) { + t.Errorf("RandSeq should be AlphaNumeric only") + } +} + +func TestGetPadding(t *testing.T) { + assert := assert.New(t) + type testCase struct { + input int + result int + } + + cases := []testCase{ + {input: 0, result: 0}, + {input: 1, result: 3}, + {input: 2, result: 2}, + {input: 3, result: 1}, + {input: 4, result: 0}, + {input: 100, result: 0}, + {input: 500, result: 0}, + } + for _, testCase := range cases { + assert.Equalf(GetPadding(testCase.input), testCase.result, "Test case returned wrong value for input %d", testCase.input) + } +} diff --git a/pkg/rtcp/goodbye.go b/pkg/rtcp/goodbye.go index 2e9a76d3..d8e77f06 100644 --- a/pkg/rtcp/goodbye.go +++ b/pkg/rtcp/goodbye.go @@ -2,6 +2,8 @@ package rtcp import ( "encoding/binary" + + "github.com/pions/webrtc/internal/util" ) // The Goodbye packet indicates that one or more sources are no longer active. @@ -49,10 +51,7 @@ func (g Goodbye) Marshal() ([]byte, error) { rawPacket = append(rawPacket, reason...) // align to 32-bit boundary - if len(rawPacket)%4 != 0 { - padCount := 4 - len(rawPacket)%4 - rawPacket = append(rawPacket, make([]byte, padCount)...) - } + rawPacket = append(rawPacket, make([]byte, util.GetPadding(len(rawPacket)))...) } h := Header{ @@ -96,7 +95,7 @@ func (g *Goodbye) Unmarshal(rawPacket []byte) error { return errWrongType } - if len(rawPacket)%4 != 0 { + if util.GetPadding(len(rawPacket)) != 0 { return errPacketTooShort } diff --git a/pkg/rtcp/source_description.go b/pkg/rtcp/source_description.go index 6266f303..0bfe0a1c 100644 --- a/pkg/rtcp/source_description.go +++ b/pkg/rtcp/source_description.go @@ -2,6 +2,8 @@ package rtcp import ( "encoding/binary" + + "github.com/pions/webrtc/internal/util" ) // SDESType is the item type used in the RTCP SDES control packet. @@ -187,10 +189,7 @@ func (s SourceDescriptionChunk) Marshal() ([]byte, error) { rawPacket = append(rawPacket, uint8(SDESEnd)) // additional null octets MUST be included if needed to pad until the next 32-bit boundary - if size := len(rawPacket); size%4 != 0 { - padding := make([]byte, 4-size%4) - rawPacket = append(rawPacket, padding...) - } + rawPacket = append(rawPacket, make([]byte, util.GetPadding(len(rawPacket)))...) return rawPacket, nil } @@ -236,9 +235,7 @@ func (s SourceDescriptionChunk) len() int { len += sdesTypeLen // for terminating null octet // align to 32-bit boundary - if len%4 != 0 { - len += 4 - (len % 4) - } + len += util.GetPadding(len) return len }