mirror of
https://github.com/gortc/stun.git
synced 2025-10-05 00:32:47 +08:00
all: remove usage of constant errors
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
package stun
|
package stun
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
// ErrorCodeAttribute represents ERROR-CODE attribute.
|
// ErrorCodeAttribute represents ERROR-CODE attribute.
|
||||||
type ErrorCodeAttribute struct {
|
type ErrorCodeAttribute struct {
|
||||||
Code ErrorCode
|
Code ErrorCode
|
||||||
@@ -15,21 +17,26 @@ const (
|
|||||||
errorCodeModulo = 100
|
errorCodeModulo = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrReasonLengthTooBig means that len(Reason) > 763 bytes.
|
||||||
|
var ErrReasonLengthTooBig = errors.New("reason for ERROR-CODE is too big")
|
||||||
|
|
||||||
// AddTo adds ERROR-CODE to m.
|
// AddTo adds ERROR-CODE to m.
|
||||||
func (c *ErrorCodeAttribute) AddTo(m *Message) error {
|
func (c *ErrorCodeAttribute) AddTo(m *Message) error {
|
||||||
value := make([]byte,
|
value := make([]byte, 0, errorCodeReasonMaxB)
|
||||||
errorCodeReasonStart, errorCodeReasonMaxB,
|
if len(c.Reason) > errorCodeReasonMaxB {
|
||||||
)
|
return ErrReasonLengthTooBig
|
||||||
|
}
|
||||||
|
value = value[:errorCodeReasonStart+len(c.Reason)]
|
||||||
number := byte(c.Code % errorCodeModulo) // error code modulo 100
|
number := byte(c.Code % errorCodeModulo) // error code modulo 100
|
||||||
class := byte(c.Code / errorCodeModulo) // hundred digit
|
class := byte(c.Code / errorCodeModulo) // hundred digit
|
||||||
value[errorCodeClassByte] = class
|
value[errorCodeClassByte] = class
|
||||||
value[errorCodeNumberByte] = number
|
value[errorCodeNumberByte] = number
|
||||||
value = append(value, c.Reason...)
|
copy(value[errorCodeReasonStart:], c.Reason)
|
||||||
m.Add(AttrErrorCode, value)
|
m.Add(AttrErrorCode, value)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFrom decodes ERROR-CODE from m.
|
// GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid.
|
||||||
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
|
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
|
||||||
v, err := m.Get(AttrErrorCode)
|
v, err := m.Get(AttrErrorCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -39,10 +46,9 @@ func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
|
|||||||
class = uint16(v[errorCodeClassByte])
|
class = uint16(v[errorCodeClassByte])
|
||||||
number = uint16(v[errorCodeNumberByte])
|
number = uint16(v[errorCodeNumberByte])
|
||||||
code = int(class*errorCodeModulo + number)
|
code = int(class*errorCodeModulo + number)
|
||||||
reason = v[errorCodeReasonStart:]
|
|
||||||
)
|
)
|
||||||
c.Code = ErrorCode(code)
|
c.Code = ErrorCode(code)
|
||||||
c.Reason = reason
|
c.Reason = v[errorCodeReasonStart:]
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,7 +57,7 @@ type ErrorCode int
|
|||||||
|
|
||||||
// ErrNoDefaultReason means that default reason for provided error code
|
// ErrNoDefaultReason means that default reason for provided error code
|
||||||
// is not defined in RFC.
|
// is not defined in RFC.
|
||||||
const ErrNoDefaultReason Error = "No default reason for ErrorCode"
|
var ErrNoDefaultReason = errors.New("No default reason for ErrorCode")
|
||||||
|
|
||||||
// AddTo adds ERROR-CODE with default reason to m. If there
|
// AddTo adds ERROR-CODE with default reason to m. If there
|
||||||
// is no default reason, returns ErrNoDefaultReason.
|
// is no default reason, returns ErrNoDefaultReason.
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package stun
|
package stun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
@@ -31,7 +32,7 @@ func isZeros(p net.IP) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
|
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
|
||||||
const ErrBadIPLength Error = "invalid length if IP value"
|
var ErrBadIPLength = errors.New("invalid length if IP value")
|
||||||
|
|
||||||
// AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength
|
// AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength
|
||||||
// if len(a.IP) is invalid.
|
// if len(a.IP) is invalid.
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package stun
|
package stun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
@@ -118,8 +119,6 @@ func (t AttrType) String() string {
|
|||||||
// don't understand, but cannot successfully process a message if it
|
// don't understand, but cannot successfully process a message if it
|
||||||
// contains comprehension-required attributes that are not
|
// contains comprehension-required attributes that are not
|
||||||
// understood.
|
// understood.
|
||||||
//
|
|
||||||
// TODO(ar): Decide to use pointer or non-pointer RawAttribute.
|
|
||||||
type RawAttribute struct {
|
type RawAttribute struct {
|
||||||
Type AttrType
|
Type AttrType
|
||||||
Length uint16 // ignored while encoding
|
Length uint16 // ignored while encoding
|
||||||
@@ -157,7 +156,7 @@ func (a RawAttribute) String() string {
|
|||||||
|
|
||||||
// ErrAttributeNotFound means that attribute with provided attribute
|
// ErrAttributeNotFound means that attribute with provided attribute
|
||||||
// type does not exist in message.
|
// type does not exist in message.
|
||||||
const ErrAttributeNotFound Error = "Attribute not found"
|
var ErrAttributeNotFound = errors.New("Attribute not found")
|
||||||
|
|
||||||
// Get returns byte slice that represents attribute value,
|
// Get returns byte slice that represents attribute value,
|
||||||
// if there is no attribute with such type,
|
// if there is no attribute with such type,
|
||||||
|
@@ -212,6 +212,19 @@ func BenchmarkErrorCodeAttribute_AddTo(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkErrorCodeAttribute_GetFrom(b *testing.B) {
|
||||||
|
m := New()
|
||||||
|
b.ReportAllocs()
|
||||||
|
a := &ErrorCodeAttribute{
|
||||||
|
Code: 404,
|
||||||
|
Reason: []byte("not found!"),
|
||||||
|
}
|
||||||
|
a.AddTo(m)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
a.GetFrom(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMessage_AddErrorCode(t *testing.T) {
|
func TestMessage_AddErrorCode(t *testing.T) {
|
||||||
m := New()
|
m := New()
|
||||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||||
@@ -243,3 +256,11 @@ func TestMessage_AddErrorCode(t *testing.T) {
|
|||||||
t.Error("bad reason", string(errCodeAttr.Reason))
|
t.Error("bad reason", string(errCodeAttr.Reason))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkMessage_GetNotFound(b *testing.B) {
|
||||||
|
m := New()
|
||||||
|
b.ReportAllocs()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m.Get(AttrRealm)
|
||||||
|
}
|
||||||
|
}
|
@@ -1,14 +1,5 @@
|
|||||||
package stun
|
package stun
|
||||||
|
|
||||||
// Error is error type for constant errors in stun package.
|
|
||||||
//
|
|
||||||
// See http://dave.cheney.net/2016/04/07/constant-errors for more info.
|
|
||||||
type Error string
|
|
||||||
|
|
||||||
func (e Error) Error() string {
|
|
||||||
return string(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecodeErr records an error and place when it is occurred.
|
// DecodeErr records an error and place when it is occurred.
|
||||||
type DecodeErr struct {
|
type DecodeErr struct {
|
||||||
Place DecodeErrPlace
|
Place DecodeErrPlace
|
||||||
|
@@ -1,25 +0,0 @@
|
|||||||
package stun
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestDecodeErr(t *testing.T) {
|
|
||||||
err := newDecodeErr("parent", "children", "message")
|
|
||||||
if !err.IsPlace(DecodeErrPlace{Parent: "parent", Children: "children"}) {
|
|
||||||
t.Error("isPlace test failed")
|
|
||||||
}
|
|
||||||
if !err.IsPlaceParent("parent") {
|
|
||||||
t.Error("parent test failed")
|
|
||||||
}
|
|
||||||
if !err.IsPlaceChildren("children") {
|
|
||||||
t.Error("children test failed")
|
|
||||||
}
|
|
||||||
if err.Error() != "BadFormat for parent/children: message" {
|
|
||||||
t.Error("bad Error string")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestError_Error(t *testing.T) {
|
|
||||||
if Error("error").Error() != "error" {
|
|
||||||
t.Error("bad Error string")
|
|
||||||
}
|
|
||||||
}
|
|
@@ -24,6 +24,7 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -261,11 +262,9 @@ func (m *Message) ReadFrom(r io.Reader) (int64, error) {
|
|||||||
return int64(n), m.Decode()
|
return int64(n), m.Decode()
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
|
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
|
||||||
// m.Raw to read header.
|
// m.Raw to read header.
|
||||||
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header"
|
var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header")
|
||||||
)
|
|
||||||
|
|
||||||
// Decode decodes m.Raw into m.
|
// Decode decodes m.Raw into m.
|
||||||
func (m *Message) Decode() error {
|
func (m *Message) Decode() error {
|
||||||
|
13
stun_test.go
13
stun_test.go
@@ -15,19 +15,12 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func bUint16(v uint16) string {
|
func bUint16(v uint16) string {
|
||||||
return fmt.Sprintf("0b%016s", strconv.FormatUint(uint64(v), 2))
|
return fmt.Sprintf("0b%016s", strconv.FormatUint(uint64(v), 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
|
||||||
log.SetLevel(log.DebugLevel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Message) reader() *bytes.Reader {
|
func (m *Message) reader() *bytes.Reader {
|
||||||
return bytes.NewReader(m.Raw)
|
return bytes.NewReader(m.Raw)
|
||||||
}
|
}
|
||||||
@@ -185,7 +178,7 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
|
|||||||
mDecoded := New()
|
mDecoded := New()
|
||||||
binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length
|
binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length
|
||||||
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2]))
|
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2]))
|
||||||
switch e := errors.Cause(err).(type) {
|
switch e := err.(type) {
|
||||||
case *DecodeErr:
|
case *DecodeErr:
|
||||||
if !e.IsPlace(DecodeErrPlace{"attribute", "header"}) {
|
if !e.IsPlace(DecodeErrPlace{"attribute", "header"}) {
|
||||||
t.Error(e, "bad place")
|
t.Error(e, "bad place")
|
||||||
@@ -213,7 +206,7 @@ func TestMessage_AttrSizeLessThanLength(t *testing.T) {
|
|||||||
bin.PutUint16(m.Raw[2:4], 5) // rewrite to bad length
|
bin.PutUint16(m.Raw[2:4], 5) // rewrite to bad length
|
||||||
mDecoded := New()
|
mDecoded := New()
|
||||||
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5]))
|
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5]))
|
||||||
switch e := errors.Cause(err).(type) {
|
switch e := err.(type) {
|
||||||
case *DecodeErr:
|
case *DecodeErr:
|
||||||
if !e.IsPlace(DecodeErrPlace{"attribute", "value"}) {
|
if !e.IsPlace(DecodeErrPlace{"attribute", "value"}) {
|
||||||
t.Error(e, "bad place")
|
t.Error(e, "bad place")
|
||||||
@@ -232,7 +225,7 @@ func (r unexpectedEOFReader) Read(b []byte) (int, error) {
|
|||||||
func TestMessage_ReadFromError(t *testing.T) {
|
func TestMessage_ReadFromError(t *testing.T) {
|
||||||
mDecoded := New()
|
mDecoded := New()
|
||||||
_, err := mDecoded.ReadFrom(unexpectedEOFReader{})
|
_, err := mDecoded.ReadFrom(unexpectedEOFReader{})
|
||||||
if errors.Cause(err) != io.ErrUnexpectedEOF {
|
if err != io.ErrUnexpectedEOF {
|
||||||
t.Error(err, "should be", io.ErrUnexpectedEOF)
|
t.Error(err, "should be", io.ErrUnexpectedEOF)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user