attrs: add CheckOverflow and IsAttrSizeOverflow

This commit is contained in:
Aleksandr Razumov
2018-08-02 18:06:32 +03:00
parent b5912456d3
commit 872811ad91
15 changed files with 125 additions and 92 deletions

View File

@@ -71,4 +71,6 @@ test-integration:
@cd e2e && bash ./test.sh @cd e2e && bash ./test.sh
prepush: test lint test-integration prepush: test lint test-integration
check-api: check-api:
@api -c api/stun1.txt,api/stun1.12.txt github.com/gortc/stun @api -c api/stun1.txt github.com/gortc/stun
write-api:
@api github.com/gortc/stun > api/stun1.txt

View File

@@ -1 +0,0 @@
pkg github.com/gortc/stun, method (*Message) ForEach(AttrType, func(*Message) error) error

View File

@@ -103,11 +103,13 @@ pkg github.com/gortc/stun, const MethodSend Method
pkg github.com/gortc/stun, const TransactionIDSize = 12 pkg github.com/gortc/stun, const TransactionIDSize = 12
pkg github.com/gortc/stun, const TransactionIDSize ideal-int pkg github.com/gortc/stun, const TransactionIDSize ideal-int
pkg github.com/gortc/stun, func Build(...Setter) (*Message, error) pkg github.com/gortc/stun, func Build(...Setter) (*Message, error)
pkg github.com/gortc/stun, func CheckOverflow(AttrType, int, int) error
pkg github.com/gortc/stun, func CheckSize(AttrType, int, int) error pkg github.com/gortc/stun, func CheckSize(AttrType, int, int) error
pkg github.com/gortc/stun, func Decode([]uint8, *Message) error pkg github.com/gortc/stun, func Decode([]uint8, *Message) error
pkg github.com/gortc/stun, func Dial(string, string) (*Client, error) pkg github.com/gortc/stun, func Dial(string, string) (*Client, error)
pkg github.com/gortc/stun, func FingerprintValue([]uint8) uint32 pkg github.com/gortc/stun, func FingerprintValue([]uint8) uint32
pkg github.com/gortc/stun, func IsAttrSizeInvalid(error) bool pkg github.com/gortc/stun, func IsAttrSizeInvalid(error) bool
pkg github.com/gortc/stun, func IsAttrSizeOverflow(error) bool
pkg github.com/gortc/stun, func IsMessage([]uint8) bool pkg github.com/gortc/stun, func IsMessage([]uint8) bool
pkg github.com/gortc/stun, func MustBuild(...Setter) *Message pkg github.com/gortc/stun, func MustBuild(...Setter) *Message
pkg github.com/gortc/stun, func New() *Message pkg github.com/gortc/stun, func New() *Message
@@ -146,6 +148,7 @@ pkg github.com/gortc/stun, method (*Message) Contains(AttrType) bool
pkg github.com/gortc/stun, method (*Message) Decode() error pkg github.com/gortc/stun, method (*Message) Decode() error
pkg github.com/gortc/stun, method (*Message) Encode() pkg github.com/gortc/stun, method (*Message) Encode()
pkg github.com/gortc/stun, method (*Message) Equal(*Message) bool pkg github.com/gortc/stun, method (*Message) Equal(*Message) bool
pkg github.com/gortc/stun, method (*Message) ForEach(AttrType, func(*Message) error) error
pkg github.com/gortc/stun, method (*Message) Get(AttrType) ([]uint8, error) pkg github.com/gortc/stun, method (*Message) Get(AttrType) ([]uint8, error)
pkg github.com/gortc/stun, method (*Message) NewTransactionID() error pkg github.com/gortc/stun, method (*Message) NewTransactionID() error
pkg github.com/gortc/stun, method (*Message) Parse(...Getter) error pkg github.com/gortc/stun, method (*Message) Parse(...Getter) error
@@ -169,8 +172,6 @@ pkg github.com/gortc/stun, method (*UnknownAttributes) GetFrom(*Message) error
pkg github.com/gortc/stun, method (*Username) GetFrom(*Message) error pkg github.com/gortc/stun, method (*Username) GetFrom(*Message) error
pkg github.com/gortc/stun, method (*XORMappedAddress) GetFrom(*Message) error pkg github.com/gortc/stun, method (*XORMappedAddress) GetFrom(*Message) error
pkg github.com/gortc/stun, method (*XORMappedAddress) GetFromAs(*Message, AttrType) error pkg github.com/gortc/stun, method (*XORMappedAddress) GetFromAs(*Message, AttrType) error
pkg github.com/gortc/stun, method (AttrLengthErr) Error() string
pkg github.com/gortc/stun, method (AttrOverflowErr) Error() string
pkg github.com/gortc/stun, method (AttrType) Optional() bool pkg github.com/gortc/stun, method (AttrType) Optional() bool
pkg github.com/gortc/stun, method (AttrType) Required() bool pkg github.com/gortc/stun, method (AttrType) Required() bool
pkg github.com/gortc/stun, method (AttrType) String() string pkg github.com/gortc/stun, method (AttrType) String() string
@@ -221,14 +222,6 @@ pkg github.com/gortc/stun, type AgentOptions struct, Handler Handler
pkg github.com/gortc/stun, type AlternateServer struct pkg github.com/gortc/stun, type AlternateServer struct
pkg github.com/gortc/stun, type AlternateServer struct, IP net.IP pkg github.com/gortc/stun, type AlternateServer struct, IP net.IP
pkg github.com/gortc/stun, type AlternateServer struct, Port int pkg github.com/gortc/stun, type AlternateServer struct, Port int
pkg github.com/gortc/stun, type AttrLengthErr struct
pkg github.com/gortc/stun, type AttrLengthErr struct, Attr AttrType
pkg github.com/gortc/stun, type AttrLengthErr struct, Expected int
pkg github.com/gortc/stun, type AttrLengthErr struct, Got int
pkg github.com/gortc/stun, type AttrOverflowErr struct
pkg github.com/gortc/stun, type AttrOverflowErr struct, Got int
pkg github.com/gortc/stun, type AttrOverflowErr struct, Max int
pkg github.com/gortc/stun, type AttrOverflowErr struct, Type AttrType
pkg github.com/gortc/stun, type AttrType uint16 pkg github.com/gortc/stun, type AttrType uint16
pkg github.com/gortc/stun, type Attributes []RawAttribute pkg github.com/gortc/stun, type Attributes []RawAttribute
pkg github.com/gortc/stun, type Checker interface { Check } pkg github.com/gortc/stun, type Checker interface { Check }
@@ -309,6 +302,8 @@ pkg github.com/gortc/stun, var BindingSuccess MessageType
pkg github.com/gortc/stun, var ErrAgentClosed error pkg github.com/gortc/stun, var ErrAgentClosed error
pkg github.com/gortc/stun, var ErrAttrSizeInvalid error pkg github.com/gortc/stun, var ErrAttrSizeInvalid error
pkg github.com/gortc/stun, var ErrAttributeNotFound error pkg github.com/gortc/stun, var ErrAttributeNotFound error
pkg github.com/gortc/stun, var ErrAttributeSizeInvalid error
pkg github.com/gortc/stun, var ErrAttributeSizeOverflow error
pkg github.com/gortc/stun, var ErrBadIPLength error pkg github.com/gortc/stun, var ErrBadIPLength error
pkg github.com/gortc/stun, var ErrBadUnknownAttrsSize error pkg github.com/gortc/stun, var ErrBadUnknownAttrsSize error
pkg github.com/gortc/stun, var ErrClientClosed error pkg github.com/gortc/stun, var ErrClientClosed error

View File

@@ -172,34 +172,6 @@ func (m *Message) Get(t AttrType) ([]byte, error) {
return v.Value, nil return v.Value, nil
} }
// AttrOverflowErr occurs when len(v) > Max.
type AttrOverflowErr struct {
Type AttrType
Max int
Got int
}
func (e AttrOverflowErr) Error() string {
return fmt.Sprintf("incorrect length of %s attribute: %d exceeds maximum %d",
e.Type, e.Got, e.Max,
)
}
// AttrLengthErr means that length for attribute is invalid.
type AttrLengthErr struct {
Attr AttrType
Got int
Expected int
}
func (e AttrLengthErr) Error() string {
return fmt.Sprintf("incorrect length of %s attribute: got %d, expected %d",
e.Attr,
e.Got,
e.Expected,
)
}
// STUN aligns attributes on 32-bit boundaries, attributes whose content // STUN aligns attributes on 32-bit boundaries, attributes whose content
// is not a multiple of 4 bytes are padded with 1, 2, or 3 bytes of // is not a multiple of 4 bytes are padded with 1, 2, or 3 bytes of
// padding so that its value contains a multiple of 4 bytes. The // padding so that its value contains a multiple of 4 bytes. The

33
attributes_debug.go Normal file
View File

@@ -0,0 +1,33 @@
// +build debug
package stun
import "fmt"
// AttrOverflowErr occurs when len(v) > Max.
type AttrOverflowErr struct {
Type AttrType
Max int
Got int
}
func (e AttrOverflowErr) Error() string {
return fmt.Sprintf("incorrect length of %s attribute: %d exceeds maximum %d",
e.Type, e.Got, e.Max,
)
}
// AttrLengthErr means that length for attribute is invalid.
type AttrLengthErr struct {
Attr AttrType
Got int
Expected int
}
func (e AttrLengthErr) Error() string {
return fmt.Sprintf("incorrect length of %s attribute: got %d, expected %d",
e.Attr,
e.Got,
e.Expected,
)
}

27
attributes_debug_test.go Normal file
View File

@@ -0,0 +1,27 @@
// +build debug
package stun
import "testing"
func TestAttrOverflowErr_Error(t *testing.T) {
err := AttrOverflowErr{
Got: 100,
Max: 50,
Type: AttrLifetime,
}
if err.Error() != "incorrect length of LIFETIME attribute: 100 exceeds maximum 50" {
t.Error("bad error string", err)
}
}
func TestAttrLengthErr_Error(t *testing.T) {
err := AttrLengthErr{
Attr: AttrErrorCode,
Expected: 15,
Got: 99,
}
if err.Error() != "incorrect length of ERROR-CODE attribute: got 99, expected 15" {
t.Errorf("bad error string: %s", err)
}
}

View File

@@ -69,28 +69,6 @@ func TestPadding(t *testing.T) {
} }
} }
func TestAttrLengthError_Error(t *testing.T) {
err := AttrOverflowErr{
Got: 100,
Max: 50,
Type: AttrLifetime,
}
if err.Error() != "incorrect length of LIFETIME attribute: 100 exceeds maximum 50" {
t.Error("bad error string", err)
}
}
func TestAttrLengthErr_Error(t *testing.T) {
err := AttrLengthErr{
Attr: AttrErrorCode,
Expected: 15,
Got: 99,
}
if err.Error() != "incorrect length of ERROR-CODE attribute: got 99, expected 15" {
t.Errorf("bad error string: %s", err)
}
}
func TestAttrTypeRange(t *testing.T) { func TestAttrTypeRange(t *testing.T) {
for _, a := range []AttrType{ for _, a := range []AttrType{
AttrPriority, AttrPriority,

View File

@@ -9,7 +9,7 @@ func CheckSize(_ AttrType, got, expected int) error {
if got == expected { if got == expected {
return nil return nil
} }
return ErrAttrSizeInvalid return ErrAttributeSizeInvalid
} }
func checkHMAC(got, expected []byte) error { func checkHMAC(got, expected []byte) error {
@@ -28,5 +28,18 @@ func checkFingerprint(got, expected uint32) error {
// IsAttrSizeInvalid returns true if error means that attribute size is invalid. // IsAttrSizeInvalid returns true if error means that attribute size is invalid.
func IsAttrSizeInvalid(err error) bool { func IsAttrSizeInvalid(err error) bool {
return err == ErrAttrSizeInvalid return err == ErrAttributeSizeInvalid
}
// CheckOverflow returns ErrAttributeSizeOverflow if got is bigger that max.
func CheckOverflow(_ AttrType, got, max int) error {
if got <= max {
return nil
}
return ErrAttributeSizeOverflow
}
// IsAttrSizeOverflow returns true if error means that attribute size is too big.
func IsAttrSizeOverflow(err error) bool {
return err == ErrAttributeSizeOverflow
} }

View File

@@ -41,3 +41,21 @@ func IsAttrSizeInvalid(err error) bool {
_, ok := err.(*AttrLengthErr) _, ok := err.(*AttrLengthErr)
return ok return ok
} }
// CheckOverflow returns *AttrOverflowErr if got is bigger that max.
func CheckOverflow(t AttrType, got, max int) error {
if got <= max {
return nil
}
return &AttrOverflowErr{
Type: t,
Got: got,
Max: max,
}
}
// IsAttrSizeOverflow returns true if error means that attribute size is too big.
func IsAttrSizeOverflow(err error) bool {
_, ok := err.(*AttrOverflowErr)
return ok
}

View File

@@ -30,12 +30,11 @@ const (
// AddTo adds ERROR-CODE to m. // AddTo adds ERROR-CODE to m.
func (c ErrorCodeAttribute) AddTo(m *Message) error { func (c ErrorCodeAttribute) AddTo(m *Message) error {
value := make([]byte, 0, errorCodeReasonMaxB) value := make([]byte, 0, errorCodeReasonMaxB)
if len(c.Reason) > errorCodeReasonMaxB { if err := CheckOverflow(AttrErrorCode,
return &AttrOverflowErr{ len(c.Reason)+errorCodeReasonStart,
Got: len(c.Reason) + errorCodeReasonStart, errorCodeReasonMaxB+errorCodeReasonStart,
Max: errorCodeReasonMaxB + errorCodeReasonStart, ); err != nil {
Type: AttrErrorCode, return err
}
} }
value = value[:errorCodeReasonStart+len(c.Reason)] value = value[:errorCodeReasonStart+len(c.Reason)]
number := byte(c.Code % errorCodeModulo) // error code modulo 100 number := byte(c.Code % errorCodeModulo) // error code modulo 100

View File

@@ -56,4 +56,12 @@ func newAttrDecodeErr(children, message string) *DecodeErr {
} }
// ErrAttrSizeInvalid means that decoded attribute size is invalid. // ErrAttrSizeInvalid means that decoded attribute size is invalid.
//
// DEPRECATED: use ErrAttributeSizeInvalid.
var ErrAttrSizeInvalid = errors.New("attribute size is invalid") var ErrAttrSizeInvalid = errors.New("attribute size is invalid")
// ErrAttributeSizeInvalid means that decoded attribute size is invalid.
var ErrAttributeSizeInvalid = ErrAttrSizeInvalid
// ErrAttributeSizeOverflow means that decoded attribute size is too big.
var ErrAttributeSizeOverflow = errors.New("attribute size overflow")

View File

@@ -111,12 +111,8 @@ type TextAttribute []byte
// AddToAs adds attribute with type t to m, checking maximum length. If maxLen // AddToAs adds attribute with type t to m, checking maximum length. If maxLen
// is less than 0, no check is performed. // is less than 0, no check is performed.
func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error { func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error {
if maxLen > 0 && len(v) > maxLen { if err := CheckOverflow(t, len(v), maxLen); err != nil {
return &AttrOverflowErr{ return err
Max: maxLen,
Got: len(v),
Type: t,
}
} }
m.Add(t, v) m.Add(t, v)
return nil return nil

View File

@@ -38,7 +38,7 @@ func TestSoftware_GetFrom(t *testing.T) {
func TestSoftware_AddTo_Invalid(t *testing.T) { func TestSoftware_AddTo_Invalid(t *testing.T) {
m := New() m := New()
s := make(Software, 1024) s := make(Software, 1024)
if err, ok := s.AddTo(m).(*AttrOverflowErr); !ok { if err := s.AddTo(m); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
} }
if err := s.GetFrom(m); err != ErrAttributeNotFound { if err := s.GetFrom(m); err != ErrAttributeNotFound {
@@ -88,8 +88,8 @@ func TestUsername(t *testing.T) {
m.WriteHeader() m.WriteHeader()
t.Run("Bad length", func(t *testing.T) { t.Run("Bad length", func(t *testing.T) {
badU := make(Username, 600) badU := make(Username, 600)
if err, ok := badU.AddTo(m).(*AttrOverflowErr); !ok { if err := badU.AddTo(m); !IsAttrSizeOverflow(err) {
t.Errorf("expected length error, got %v", err) t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
} }
}) })
t.Run("AddTo", func(t *testing.T) { t.Run("AddTo", func(t *testing.T) {
@@ -164,7 +164,7 @@ func TestRealm_GetFrom(t *testing.T) {
func TestRealm_AddTo_Invalid(t *testing.T) { func TestRealm_AddTo_Invalid(t *testing.T) {
m := New() m := New()
r := make(Realm, 1024) r := make(Realm, 1024)
if err, ok := r.AddTo(m).(*AttrOverflowErr); !ok || err.Type != AttrRealm { if err := r.AddTo(m); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
} }
if err := r.GetFrom(m); err != ErrAttributeNotFound { if err := r.GetFrom(m); err != ErrAttributeNotFound {
@@ -205,7 +205,7 @@ func TestNonce_GetFrom(t *testing.T) {
func TestNonce_AddTo_Invalid(t *testing.T) { func TestNonce_AddTo_Invalid(t *testing.T) {
m := New() m := New()
n := make(Nonce, 1024) n := make(Nonce, 1024)
if err, ok := n.AddTo(m).(*AttrOverflowErr); !ok || err.Type != AttrNonce { if err := n.AddTo(m); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
} }
if err := n.GetFrom(m); err != ErrAttributeNotFound { if err := n.GetFrom(m); err != ErrAttributeNotFound {

View File

@@ -108,12 +108,8 @@ func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
if len(v) <= 4 { if len(v) <= 4 {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
if len(v[4:]) > len(a.IP) { if err := CheckOverflow(t, len(v[4:]), len(a.IP)); err != nil {
return &AttrOverflowErr{ return err
Got: len(v[4:]),
Type: AttrXORMappedAddress,
Max: len(a.IP),
}
} }
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16) a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 4+TransactionIDSize) xorValue := make([]byte, 4+TransactionIDSize)

View File

@@ -80,11 +80,8 @@ func TestXORMappedAddress_GetFrom(t *testing.T) {
// {0, 1} is correct addr family. // {0, 1} is correct addr family.
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4}) m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4})
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
err := addr.GetFrom(m) if err := addr.GetFrom(m); !IsAttrSizeOverflow(err) {
if _, ok := err.(*AttrOverflowErr); !ok { t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
t.Errorf("should render AttrOverflowErr error, got <%s> (%T)",
err, err,
)
} }
}) })
} }