all: remove usage of constant errors

This commit is contained in:
Aleksandr Razumov
2017-02-11 06:54:37 +03:00
parent 03525bd44e
commit f1b2352162
8 changed files with 46 additions and 61 deletions

View File

@@ -1,5 +1,7 @@
package stun package stun
import "errors"
// ErrorCodeAttribute represents ERROR-CODE attribute. // ErrorCodeAttribute represents ERROR-CODE attribute.
type ErrorCodeAttribute struct { type ErrorCodeAttribute struct {
Code ErrorCode Code ErrorCode
@@ -15,21 +17,26 @@ const (
errorCodeModulo = 100 errorCodeModulo = 100
) )
// ErrReasonLengthTooBig means that len(Reason) > 763 bytes.
var ErrReasonLengthTooBig = errors.New("reason for ERROR-CODE is too big")
// AddTo adds ERROR-CODE to m. // AddTo adds ERROR-CODE to m.
func (c *ErrorCodeAttribute) AddTo(m *Message) error { func (c *ErrorCodeAttribute) AddTo(m *Message) error {
value := make([]byte, value := make([]byte, 0, errorCodeReasonMaxB)
errorCodeReasonStart, errorCodeReasonMaxB, if len(c.Reason) > errorCodeReasonMaxB {
) return ErrReasonLengthTooBig
}
value = value[:errorCodeReasonStart+len(c.Reason)]
number := byte(c.Code % errorCodeModulo) // error code modulo 100 number := byte(c.Code % errorCodeModulo) // error code modulo 100
class := byte(c.Code / errorCodeModulo) // hundred digit class := byte(c.Code / errorCodeModulo) // hundred digit
value[errorCodeClassByte] = class value[errorCodeClassByte] = class
value[errorCodeNumberByte] = number value[errorCodeNumberByte] = number
value = append(value, c.Reason...) copy(value[errorCodeReasonStart:], c.Reason)
m.Add(AttrErrorCode, value) m.Add(AttrErrorCode, value)
return nil return nil
} }
// GetFrom decodes ERROR-CODE from m. // GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid.
func (c *ErrorCodeAttribute) GetFrom(m *Message) error { func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
v, err := m.Get(AttrErrorCode) v, err := m.Get(AttrErrorCode)
if err != nil { if err != nil {
@@ -39,10 +46,9 @@ func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
class = uint16(v[errorCodeClassByte]) class = uint16(v[errorCodeClassByte])
number = uint16(v[errorCodeNumberByte]) number = uint16(v[errorCodeNumberByte])
code = int(class*errorCodeModulo + number) code = int(class*errorCodeModulo + number)
reason = v[errorCodeReasonStart:]
) )
c.Code = ErrorCode(code) c.Code = ErrorCode(code)
c.Reason = reason c.Reason = v[errorCodeReasonStart:]
return nil return nil
} }
@@ -51,7 +57,7 @@ type ErrorCode int
// ErrNoDefaultReason means that default reason for provided error code // ErrNoDefaultReason means that default reason for provided error code
// is not defined in RFC. // is not defined in RFC.
const ErrNoDefaultReason Error = "No default reason for ErrorCode" var ErrNoDefaultReason = errors.New("No default reason for ErrorCode")
// AddTo adds ERROR-CODE with default reason to m. If there // AddTo adds ERROR-CODE with default reason to m. If there
// is no default reason, returns ErrNoDefaultReason. // is no default reason, returns ErrNoDefaultReason.

View File

@@ -1,6 +1,7 @@
package stun package stun
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
) )
@@ -31,7 +32,7 @@ func isZeros(p net.IP) bool {
} }
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}. // ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
const ErrBadIPLength Error = "invalid length if IP value" var ErrBadIPLength = errors.New("invalid length if IP value")
// AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength // AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength
// if len(a.IP) is invalid. // if len(a.IP) is invalid.

View File

@@ -1,6 +1,7 @@
package stun package stun
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
) )
@@ -118,8 +119,6 @@ func (t AttrType) String() string {
// don't understand, but cannot successfully process a message if it // don't understand, but cannot successfully process a message if it
// contains comprehension-required attributes that are not // contains comprehension-required attributes that are not
// understood. // understood.
//
// TODO(ar): Decide to use pointer or non-pointer RawAttribute.
type RawAttribute struct { type RawAttribute struct {
Type AttrType Type AttrType
Length uint16 // ignored while encoding Length uint16 // ignored while encoding
@@ -157,7 +156,7 @@ func (a RawAttribute) String() string {
// ErrAttributeNotFound means that attribute with provided attribute // ErrAttributeNotFound means that attribute with provided attribute
// type does not exist in message. // type does not exist in message.
const ErrAttributeNotFound Error = "Attribute not found" var ErrAttributeNotFound = errors.New("Attribute not found")
// Get returns byte slice that represents attribute value, // Get returns byte slice that represents attribute value,
// if there is no attribute with such type, // if there is no attribute with such type,

View File

@@ -212,6 +212,19 @@ func BenchmarkErrorCodeAttribute_AddTo(b *testing.B) {
} }
} }
func BenchmarkErrorCodeAttribute_GetFrom(b *testing.B) {
m := New()
b.ReportAllocs()
a := &ErrorCodeAttribute{
Code: 404,
Reason: []byte("not found!"),
}
a.AddTo(m)
for i := 0; i < b.N; i++ {
a.GetFrom(m)
}
}
func TestMessage_AddErrorCode(t *testing.T) { func TestMessage_AddErrorCode(t *testing.T) {
m := New() m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
@@ -243,3 +256,11 @@ func TestMessage_AddErrorCode(t *testing.T) {
t.Error("bad reason", string(errCodeAttr.Reason)) t.Error("bad reason", string(errCodeAttr.Reason))
} }
} }
func BenchmarkMessage_GetNotFound(b *testing.B) {
m := New()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
m.Get(AttrRealm)
}
}

View File

@@ -1,14 +1,5 @@
package stun package stun
// Error is error type for constant errors in stun package.
//
// See http://dave.cheney.net/2016/04/07/constant-errors for more info.
type Error string
func (e Error) Error() string {
return string(e)
}
// DecodeErr records an error and place when it is occurred. // DecodeErr records an error and place when it is occurred.
type DecodeErr struct { type DecodeErr struct {
Place DecodeErrPlace Place DecodeErrPlace

View File

@@ -1,25 +0,0 @@
package stun
import "testing"
func TestDecodeErr(t *testing.T) {
err := newDecodeErr("parent", "children", "message")
if !err.IsPlace(DecodeErrPlace{Parent: "parent", Children: "children"}) {
t.Error("isPlace test failed")
}
if !err.IsPlaceParent("parent") {
t.Error("parent test failed")
}
if !err.IsPlaceChildren("children") {
t.Error("children test failed")
}
if err.Error() != "BadFormat for parent/children: message" {
t.Error("bad Error string")
}
}
func TestError_Error(t *testing.T) {
if Error("error").Error() != "error" {
t.Error("bad Error string")
}
}

View File

@@ -24,6 +24,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"strconv" "strconv"
@@ -261,11 +262,9 @@ func (m *Message) ReadFrom(r io.Reader) (int64, error) {
return int64(n), m.Decode() return int64(n), m.Decode()
} }
const (
// ErrUnexpectedHeaderEOF means that there were not enough bytes in // ErrUnexpectedHeaderEOF means that there were not enough bytes in
// m.Raw to read header. // m.Raw to read header.
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header" var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header")
)
// Decode decodes m.Raw into m. // Decode decodes m.Raw into m.
func (m *Message) Decode() error { func (m *Message) Decode() error {

View File

@@ -15,19 +15,12 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
) )
func bUint16(v uint16) string { func bUint16(v uint16) string {
return fmt.Sprintf("0b%016s", strconv.FormatUint(uint64(v), 2)) return fmt.Sprintf("0b%016s", strconv.FormatUint(uint64(v), 2))
} }
func init() {
log.SetLevel(log.DebugLevel)
}
func (m *Message) reader() *bytes.Reader { func (m *Message) reader() *bytes.Reader {
return bytes.NewReader(m.Raw) return bytes.NewReader(m.Raw)
} }
@@ -185,7 +178,7 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
mDecoded := New() mDecoded := New()
binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2])) _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2]))
switch e := errors.Cause(err).(type) { switch e := err.(type) {
case *DecodeErr: case *DecodeErr:
if !e.IsPlace(DecodeErrPlace{"attribute", "header"}) { if !e.IsPlace(DecodeErrPlace{"attribute", "header"}) {
t.Error(e, "bad place") t.Error(e, "bad place")
@@ -213,7 +206,7 @@ func TestMessage_AttrSizeLessThanLength(t *testing.T) {
bin.PutUint16(m.Raw[2:4], 5) // rewrite to bad length bin.PutUint16(m.Raw[2:4], 5) // rewrite to bad length
mDecoded := New() mDecoded := New()
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5])) _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5]))
switch e := errors.Cause(err).(type) { switch e := err.(type) {
case *DecodeErr: case *DecodeErr:
if !e.IsPlace(DecodeErrPlace{"attribute", "value"}) { if !e.IsPlace(DecodeErrPlace{"attribute", "value"}) {
t.Error(e, "bad place") t.Error(e, "bad place")
@@ -232,7 +225,7 @@ func (r unexpectedEOFReader) Read(b []byte) (int, error) {
func TestMessage_ReadFromError(t *testing.T) { func TestMessage_ReadFromError(t *testing.T) {
mDecoded := New() mDecoded := New()
_, err := mDecoded.ReadFrom(unexpectedEOFReader{}) _, err := mDecoded.ReadFrom(unexpectedEOFReader{})
if errors.Cause(err) != io.ErrUnexpectedEOF { if err != io.ErrUnexpectedEOF {
t.Error(err, "should be", io.ErrUnexpectedEOF) t.Error(err, "should be", io.ErrUnexpectedEOF)
} }
} }