all: refactor errors, improve coverage, refactor

This commit is contained in:
Aleksandr Razumov
2017-06-26 03:18:07 +02:00
parent 5505d6838c
commit f139299e92
10 changed files with 113 additions and 16 deletions

View File

@@ -149,7 +149,7 @@ func (a RawAttribute) String() string {
// ErrAttributeNotFound means that attribute with provided attribute // ErrAttributeNotFound means that attribute with provided attribute
// type does not exist in message. // type does not exist in message.
var ErrAttributeNotFound = errors.New("Attribute not found") var ErrAttributeNotFound = errors.New("attribute not found")
// Get returns byte slice that represents attribute value, // Get returns byte slice that represents attribute value,
// if there is no attribute with such type, // if there is no attribute with such type,

View File

@@ -74,7 +74,7 @@ type ErrorCode int
// ErrNoDefaultReason means that default reason for provided error code // ErrNoDefaultReason means that default reason for provided error code
// is not defined in RFC. // is not defined in RFC.
var ErrNoDefaultReason = errors.New("No default reason for ErrorCode") var ErrNoDefaultReason = errors.New("no default reason for ErrorCode")
// AddTo adds ERROR-CODE with default reason to m. If there // AddTo adds ERROR-CODE with default reason to m. If there
// is no default reason, returns ErrNoDefaultReason. // is no default reason, returns ErrNoDefaultReason.

View File

@@ -98,4 +98,17 @@ func TestHelpersErrorHandling(t *testing.T) {
if err := m.Parse(e); err != e.Err { if err := m.Parse(e); err != e.Err {
t.Error(err, "!=", e.Err) t.Error(err, "!=", e.Err)
} }
t.Run("MustBuild", func(t *testing.T) {
t.Run("Positive", func(t *testing.T) {
MustBuild(NewTransactionIDSetter(transactionID{}))
})
defer func() {
if p := recover(); p != e.Err {
t.Errorf("%s != %s",
p, e.Err,
)
}
}()
MustBuild(e)
})
} }

View File

@@ -93,10 +93,7 @@ func (i *IntegrityErr) Error() string {
func newHMAC(key, message []byte) []byte { func newHMAC(key, message []byte) []byte {
mac := hmac.New(sha1.New, key) mac := hmac.New(sha1.New, key)
_, err := mac.Write(message) writeOrPanic(mac, message)
if err != nil {
panic(err)
}
return mac.Sum(nil) return mac.Sum(nil)
} }

View File

@@ -26,10 +26,7 @@ const (
// NewTransactionID returns new random transaction ID using crypto/rand // NewTransactionID returns new random transaction ID using crypto/rand
// as source. // as source.
func NewTransactionID() (b [transactionIDSize]byte) { func NewTransactionID() (b [transactionIDSize]byte) {
_, err := io.ReadFull(rand.Reader, b[:]) readFullOrPanic(rand.Reader, b[:])
if err != nil {
panic(err)
}
return b return b
} }

View File

@@ -392,7 +392,7 @@ func TestAttribute_Equal(t *testing.T) {
} }
func TestMessage_Equal(t *testing.T) { func TestMessage_Equal(t *testing.T) {
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attrs := Attributes{attr} attrs := Attributes{attr}
a := &Message{Attributes: attrs, Length: 4 + 2} a := &Message{Attributes: attrs, Length: 4 + 2}
b := &Message{Attributes: attrs, Length: 4 + 2} b := &Message{Attributes: attrs, Length: 4 + 2}
@@ -412,11 +412,18 @@ func TestMessage_Equal(t *testing.T) {
t.Error("should not equal") t.Error("should not equal")
} }
tAttrs := Attributes{ tAttrs := Attributes{
{Length: 1, Value: []byte{0x1}}, {Length: 1, Value: []byte{0x1}, Type: 0x1},
} }
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
t.Error("should not equal") t.Error("should not equal")
} }
tAttrs = Attributes{
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
}
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
t.Error("should not equal")
}
} }
func TestMessageGrow(t *testing.T) { func TestMessageGrow(t *testing.T) {

25
stun.go
View File

@@ -11,22 +11,41 @@
// package for example of stun extension implementation. // package for example of stun extension implementation.
package stun package stun
import "encoding/binary" import (
"encoding/binary"
"io"
)
// bin is shorthand to binary.BigEndian. // bin is shorthand to binary.BigEndian.
var bin = binary.BigEndian var bin = binary.BigEndian
func readFullOrPanic(r io.Reader, v []byte) int {
n, err := io.ReadFull(r, v)
if err != nil {
panic(err)
}
return n
}
func writeOrPanic(w io.Writer, v []byte) int {
n, err := w.Write(v)
if err != nil {
panic(err)
}
return n
}
// IANA assigned ports for "stun" protocol. // IANA assigned ports for "stun" protocol.
const ( const (
DefaultPort = 3478 DefaultPort = 3478
DefaultTLSPort = 5349 DefaultTLSPort = 5349
) )
type transactionIDSetter bool type transactionIDSetter struct{}
func (transactionIDSetter) AddTo(m *Message) error { func (transactionIDSetter) AddTo(m *Message) error {
return m.NewTransactionID() return m.NewTransactionID()
} }
// TransactionID is Setter for m.TransactionID. // TransactionID is Setter for m.TransactionID.
var TransactionID Setter = transactionIDSetter(true) var TransactionID Setter = transactionIDSetter{}

36
stun_test.go Normal file
View File

@@ -0,0 +1,36 @@
package stun
import (
"errors"
"testing"
)
type errorReader struct{}
func (errorReader) Read([]byte) (int, error) {
return 0, errors.New("failed to read")
}
func TestReadFullHelper(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("should panic")
}
}()
readFullOrPanic(errorReader{}, make([]byte, 1))
}
type errorWriter struct{}
func (errorWriter) Write([]byte) (int, error) {
return 0, errors.New("failed to write")
}
func TestWriteHelper(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("should panic")
}
}()
writeOrPanic(errorWriter{}, make([]byte, 1))
}

View File

@@ -109,7 +109,11 @@ func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
if len(v[4:]) > len(a.IP) { if len(v[4:]) > len(a.IP) {
return errors.New("Bad format for XOR-MAPPED-ADDRESS") return &AttrOverflowErr{
Got: len(v[4:]),
Type: AttrXORMappedAddress,
Max: len(a.IP),
}
} }
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16) a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 4+transactionIDSize) xorValue := make([]byte, 4+transactionIDSize)

View File

@@ -5,6 +5,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"io"
"net" "net"
"testing" "testing"
) )
@@ -63,6 +64,29 @@ func TestXORMappedAddress_GetFrom(t *testing.T) {
if addr.Port != 48583 { if addr.Port != 48583 {
t.Error("bad Port", addr.Port, "!=", 48583) t.Error("bad Port", addr.Port, "!=", 48583)
} }
t.Run("UnexpectedEOF", func(t *testing.T) {
m := New()
// {0, 1} is correct addr family.
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4})
addr := new(XORMappedAddress)
if err = addr.GetFrom(m); err != io.ErrUnexpectedEOF {
t.Errorf("len(v) = 4 should render <%s> error, got <%s>",
io.ErrUnexpectedEOF, err,
)
}
})
t.Run("AttrOverflowErr", func(t *testing.T) {
m := New()
// {0, 1} is correct addr family.
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4})
addr := new(XORMappedAddress)
err := addr.GetFrom(m)
if _, ok := err.(*AttrOverflowErr); !ok {
t.Errorf("should render AttrOverflowErr error, got <%s> (%T)",
err, err,
)
}
})
} }
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) { func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {