mirror of
https://github.com/gortc/stun.git
synced 2025-09-27 04:45:55 +08:00
all: refactor errors, improve coverage, refactor
This commit is contained in:
@@ -149,7 +149,7 @@ func (a RawAttribute) String() string {
|
|||||||
|
|
||||||
// ErrAttributeNotFound means that attribute with provided attribute
|
// ErrAttributeNotFound means that attribute with provided attribute
|
||||||
// type does not exist in message.
|
// type does not exist in message.
|
||||||
var ErrAttributeNotFound = errors.New("Attribute not found")
|
var ErrAttributeNotFound = errors.New("attribute not found")
|
||||||
|
|
||||||
// Get returns byte slice that represents attribute value,
|
// Get returns byte slice that represents attribute value,
|
||||||
// if there is no attribute with such type,
|
// if there is no attribute with such type,
|
||||||
|
@@ -74,7 +74,7 @@ type ErrorCode int
|
|||||||
|
|
||||||
// ErrNoDefaultReason means that default reason for provided error code
|
// ErrNoDefaultReason means that default reason for provided error code
|
||||||
// is not defined in RFC.
|
// is not defined in RFC.
|
||||||
var ErrNoDefaultReason = errors.New("No default reason for ErrorCode")
|
var ErrNoDefaultReason = errors.New("no default reason for ErrorCode")
|
||||||
|
|
||||||
// AddTo adds ERROR-CODE with default reason to m. If there
|
// AddTo adds ERROR-CODE with default reason to m. If there
|
||||||
// is no default reason, returns ErrNoDefaultReason.
|
// is no default reason, returns ErrNoDefaultReason.
|
||||||
|
@@ -98,4 +98,17 @@ func TestHelpersErrorHandling(t *testing.T) {
|
|||||||
if err := m.Parse(e); err != e.Err {
|
if err := m.Parse(e); err != e.Err {
|
||||||
t.Error(err, "!=", e.Err)
|
t.Error(err, "!=", e.Err)
|
||||||
}
|
}
|
||||||
|
t.Run("MustBuild", func(t *testing.T) {
|
||||||
|
t.Run("Positive", func(t *testing.T) {
|
||||||
|
MustBuild(NewTransactionIDSetter(transactionID{}))
|
||||||
|
})
|
||||||
|
defer func() {
|
||||||
|
if p := recover(); p != e.Err {
|
||||||
|
t.Errorf("%s != %s",
|
||||||
|
p, e.Err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
MustBuild(e)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@@ -93,10 +93,7 @@ func (i *IntegrityErr) Error() string {
|
|||||||
|
|
||||||
func newHMAC(key, message []byte) []byte {
|
func newHMAC(key, message []byte) []byte {
|
||||||
mac := hmac.New(sha1.New, key)
|
mac := hmac.New(sha1.New, key)
|
||||||
_, err := mac.Write(message)
|
writeOrPanic(mac, message)
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return mac.Sum(nil)
|
return mac.Sum(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -26,10 +26,7 @@ const (
|
|||||||
// NewTransactionID returns new random transaction ID using crypto/rand
|
// NewTransactionID returns new random transaction ID using crypto/rand
|
||||||
// as source.
|
// as source.
|
||||||
func NewTransactionID() (b [transactionIDSize]byte) {
|
func NewTransactionID() (b [transactionIDSize]byte) {
|
||||||
_, err := io.ReadFull(rand.Reader, b[:])
|
readFullOrPanic(rand.Reader, b[:])
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -392,7 +392,7 @@ func TestAttribute_Equal(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_Equal(t *testing.T) {
|
func TestMessage_Equal(t *testing.T) {
|
||||||
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
||||||
attrs := Attributes{attr}
|
attrs := Attributes{attr}
|
||||||
a := &Message{Attributes: attrs, Length: 4 + 2}
|
a := &Message{Attributes: attrs, Length: 4 + 2}
|
||||||
b := &Message{Attributes: attrs, Length: 4 + 2}
|
b := &Message{Attributes: attrs, Length: 4 + 2}
|
||||||
@@ -412,11 +412,18 @@ func TestMessage_Equal(t *testing.T) {
|
|||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
tAttrs := Attributes{
|
tAttrs := Attributes{
|
||||||
{Length: 1, Value: []byte{0x1}},
|
{Length: 1, Value: []byte{0x1}, Type: 0x1},
|
||||||
}
|
}
|
||||||
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||||
t.Error("should not equal")
|
t.Error("should not equal")
|
||||||
}
|
}
|
||||||
|
tAttrs = Attributes{
|
||||||
|
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
|
||||||
|
}
|
||||||
|
if a.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||||
|
t.Error("should not equal")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessageGrow(t *testing.T) {
|
func TestMessageGrow(t *testing.T) {
|
||||||
|
25
stun.go
25
stun.go
@@ -11,22 +11,41 @@
|
|||||||
// package for example of stun extension implementation.
|
// package for example of stun extension implementation.
|
||||||
package stun
|
package stun
|
||||||
|
|
||||||
import "encoding/binary"
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
// bin is shorthand to binary.BigEndian.
|
// bin is shorthand to binary.BigEndian.
|
||||||
var bin = binary.BigEndian
|
var bin = binary.BigEndian
|
||||||
|
|
||||||
|
func readFullOrPanic(r io.Reader, v []byte) int {
|
||||||
|
n, err := io.ReadFull(r, v)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOrPanic(w io.Writer, v []byte) int {
|
||||||
|
n, err := w.Write(v)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
// IANA assigned ports for "stun" protocol.
|
// IANA assigned ports for "stun" protocol.
|
||||||
const (
|
const (
|
||||||
DefaultPort = 3478
|
DefaultPort = 3478
|
||||||
DefaultTLSPort = 5349
|
DefaultTLSPort = 5349
|
||||||
)
|
)
|
||||||
|
|
||||||
type transactionIDSetter bool
|
type transactionIDSetter struct{}
|
||||||
|
|
||||||
func (transactionIDSetter) AddTo(m *Message) error {
|
func (transactionIDSetter) AddTo(m *Message) error {
|
||||||
return m.NewTransactionID()
|
return m.NewTransactionID()
|
||||||
}
|
}
|
||||||
|
|
||||||
// TransactionID is Setter for m.TransactionID.
|
// TransactionID is Setter for m.TransactionID.
|
||||||
var TransactionID Setter = transactionIDSetter(true)
|
var TransactionID Setter = transactionIDSetter{}
|
||||||
|
36
stun_test.go
Normal file
36
stun_test.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type errorReader struct{}
|
||||||
|
|
||||||
|
func (errorReader) Read([]byte) (int, error) {
|
||||||
|
return 0, errors.New("failed to read")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFullHelper(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Error("should panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
readFullOrPanic(errorReader{}, make([]byte, 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
type errorWriter struct{}
|
||||||
|
|
||||||
|
func (errorWriter) Write([]byte) (int, error) {
|
||||||
|
return 0, errors.New("failed to write")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteHelper(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Error("should panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
writeOrPanic(errorWriter{}, make([]byte, 1))
|
||||||
|
}
|
@@ -109,7 +109,11 @@ func (a *XORMappedAddress) GetFromAs(m *Message, t AttrType) error {
|
|||||||
return io.ErrUnexpectedEOF
|
return io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
if len(v[4:]) > len(a.IP) {
|
if len(v[4:]) > len(a.IP) {
|
||||||
return errors.New("Bad format for XOR-MAPPED-ADDRESS")
|
return &AttrOverflowErr{
|
||||||
|
Got: len(v[4:]),
|
||||||
|
Type: AttrXORMappedAddress,
|
||||||
|
Max: len(a.IP),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
|
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
|
||||||
xorValue := make([]byte, 4+transactionIDSize)
|
xorValue := make([]byte, 4+transactionIDSize)
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -63,6 +64,29 @@ func TestXORMappedAddress_GetFrom(t *testing.T) {
|
|||||||
if addr.Port != 48583 {
|
if addr.Port != 48583 {
|
||||||
t.Error("bad Port", addr.Port, "!=", 48583)
|
t.Error("bad Port", addr.Port, "!=", 48583)
|
||||||
}
|
}
|
||||||
|
t.Run("UnexpectedEOF", func(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
// {0, 1} is correct addr family.
|
||||||
|
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4})
|
||||||
|
addr := new(XORMappedAddress)
|
||||||
|
if err = addr.GetFrom(m); err != io.ErrUnexpectedEOF {
|
||||||
|
t.Errorf("len(v) = 4 should render <%s> error, got <%s>",
|
||||||
|
io.ErrUnexpectedEOF, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("AttrOverflowErr", func(t *testing.T) {
|
||||||
|
m := New()
|
||||||
|
// {0, 1} is correct addr family.
|
||||||
|
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4})
|
||||||
|
addr := new(XORMappedAddress)
|
||||||
|
err := addr.GetFrom(m)
|
||||||
|
if _, ok := err.(*AttrOverflowErr); !ok {
|
||||||
|
t.Errorf("should render AttrOverflowErr error, got <%s> (%T)",
|
||||||
|
err, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
|
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
|
||||||
|
Reference in New Issue
Block a user