all: more refactoring for refactor god

This commit is contained in:
Aleksandr Razumov
2017-02-04 14:56:22 -08:00
parent 6f411f0187
commit c17c2351ea
5 changed files with 88 additions and 35 deletions

View File

@@ -14,7 +14,7 @@ type AttrWriter interface {
// AttrEncoder wraps Encode method. // AttrEncoder wraps Encode method.
type AttrEncoder interface { type AttrEncoder interface {
Encode(w AttrWriter, m *Message) error Encode(b []byte, m *Message) (AttrType, []byte, error)
} }
// Attributes is list of message attributes. // Attributes is list of message attributes.
@@ -200,21 +200,24 @@ func (b *bufEncoder) AddRaw(t AttrType, v []byte) {
} }
// Set sets the value of attribute if it presents. // Set sets the value of attribute if it presents.
func (m *Message) Set(t AttrType, v AttrEncoder) error { func (m *Message) Set(a AttrEncoder) error {
var ( var (
a bufEncoder v []byte
err error
t AttrType
) )
if err := v.Encode(&a, m); err != nil { t, v, err = a.Encode(v, m)
return err
}
buf, err := m.getAttrValue(a.Type)
if err != nil { if err != nil {
return err return err
} }
if len(a.Value) != len(buf) { buf, err := m.getAttrValue(t)
if err != nil {
return err
}
if len(v) != len(buf) {
return ErrBadSetLength return ErrBadSetLength
} }
copy(buf, a.Value) copy(buf, v)
return nil return nil
} }
@@ -246,8 +249,7 @@ type XORMappedAddress struct {
} }
// Encode implements AttrEncoder. // Encode implements AttrEncoder.
// TODO(ar): fix signature. func (a *XORMappedAddress) Encode(buf []byte, m *Message) (AttrType, []byte, error) {
func (a *XORMappedAddress) Encode(m *Message) ([]byte, error) {
// X-Port is computed by taking the mapped port in host byte order, // X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and // XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order. // then the converting the result to network byte order.
@@ -267,7 +269,8 @@ func (a *XORMappedAddress) Encode(m *Message) ([]byte, error) {
binary.BigEndian.PutUint16(value[0:2], uint16(family)) binary.BigEndian.PutUint16(value[0:2], uint16(family))
binary.BigEndian.PutUint16(value[2:4], uint16(port)) binary.BigEndian.PutUint16(value[2:4], uint16(port))
xorBytes(value[4:4+len(ip)], ip, xorValue) xorBytes(value[4:4+len(ip)], ip, xorValue)
return value, nil buf = append(buf, value...)
return AttrXORMappedAddress, buf, nil
} }
// Decode implements AttrDecoder. // Decode implements AttrDecoder.
@@ -436,9 +439,8 @@ func NewSoftware(software string) *Software {
} }
// Encode implements AttrEncoder. // Encode implements AttrEncoder.
func (s *Software) Encode(w AttrWriter, m *Message) error { func (s *Software) Encode(b []byte, m *Message) (AttrType, []byte, error) {
w.AddRaw(AttrSoftware, s.Raw) return AttrSoftware, append(b, s.Raw...), nil
return nil
} }
// Decode implements AttrDecoder. // Decode implements AttrDecoder.

View File

@@ -16,7 +16,9 @@ func TestMessage_AddSoftware(t *testing.T) {
m.AddRaw(AttrSoftware, []byte(v)) m.AddRaw(AttrSoftware, []byte(v))
m.WriteHeader() m.WriteHeader()
m2 := New() m2 := &Message{
Raw: make([]byte, 0, 256),
}
if _, err := m2.ReadFrom(m.reader()); err != nil { if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err) t.Error(err)
} }

52
stun.go
View File

@@ -117,7 +117,7 @@ func NewTransactionID() (b [transactionIDSize]byte) {
// defaults for pool. // defaults for pool.
const ( const (
defaultMessageBufferCapacity = 416 defaultMessageBufferCapacity = 120
) )
// New returns *Message with allocated Raw. // New returns *Message with allocated Raw.
@@ -150,7 +150,22 @@ func (m *Message) grow(v int) {
// Add adds AttrEncoder to message, calling Encode method. // Add adds AttrEncoder to message, calling Encode method.
func (m *Message) Add(a AttrEncoder) error { func (m *Message) Add(a AttrEncoder) error {
return a.Encode(m, m) var (
err error
t AttrType
)
initial := len(m.Raw)
for i := 0; i < attributeHeaderSize; i++ {
m.Raw = append(m.Raw, 0)
}
start := len(m.Raw)
t, m.Raw, err = a.Encode(m.Raw, m)
if err != nil {
m.Raw = m.Raw[:initial]
return err
}
m.AddRaw(t, m.Raw[start:])
return nil
} }
// AddRaw appends new attribute to message. Not goroutine-safe. // AddRaw appends new attribute to message. Not goroutine-safe.
@@ -158,6 +173,7 @@ func (m *Message) Add(a AttrEncoder) error {
// Value of attribute is copied to internal buffer so // Value of attribute is copied to internal buffer so
// it is safe to reuse v. // it is safe to reuse v.
func (m *Message) AddRaw(t AttrType, v []byte) { func (m *Message) AddRaw(t AttrType, v []byte) {
// CPU: suboptimal;
// allocating memory for TLV (type-length-value), where // allocating memory for TLV (type-length-value), where
// type-length is attribute header. // type-length is attribute header.
// m.buf.B[0:20] is reserved by header // m.buf.B[0:20] is reserved by header
@@ -278,36 +294,28 @@ func (m *Message) Append(v []byte) []byte {
return append(v, m.Raw...) return append(v, m.Raw...)
} }
// Bytes returns message raw value.
// Deprecated: use m.Raw.
func (m *Message) Bytes() []byte {
return m.Raw
}
// WriteToConn writes a packet with message to addr, using c. // WriteToConn writes a packet with message to addr, using c.
// Deprecated; non-idiomatic. // Deprecated; non-idiomatic.
func (m *Message) WriteToConn(c net.PacketConn, addr net.Addr) (n int, err error) { func (m *Message) WriteToConn(c net.PacketConn, addr net.Addr) (n int, err error) {
return c.WriteTo(m.Raw, addr) return c.WriteTo(m.Raw, addr)
} }
// ReadFrom implements ReaderFrom. Decodes message and return error if any. // ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
// Decodes it and return error if any. If m.Raw is too small, will return
// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
// //
// Can return ErrUnexpectedEOF, ErrInvalidMagicCookie, ErrInvalidMessageLength. // Can return *DecodeErr while decoding.
// Any error is unrecoverable, but message could be partially decoded.
//
// ErrUnexpectedEOF means that there were not enough bytes to read header or
// Deprecated: use Decode.
func (m *Message) ReadFrom(r io.Reader) (int64, error) { func (m *Message) ReadFrom(r io.Reader) (int64, error) {
tBuf := make([]byte, 0, MaxPacketSize) tBuf := m.Raw[:cap(m.Raw)]
var ( var (
n int n int
err error err error
) )
if n, err = r.Read(tBuf[:MaxPacketSize]); err != nil { if n, err = r.Read(tBuf); err != nil {
return int64(n), err return int64(n), err
} }
n, err = m.ReadBytes(tBuf[:n]) m.Raw = tBuf[:n]
return int64(n), err return int64(n), m.Decode()
} }
func newAttrDecodeErr(children, message string) *DecodeErr { func newAttrDecodeErr(children, message string) *DecodeErr {
@@ -324,10 +332,18 @@ func IsMessage(b []byte) bool {
binary.BigEndian.Uint32(b[4:8]) == magicCookie binary.BigEndian.Uint32(b[4:8]) == magicCookie
} }
var (
// ErrUnexpectedHeaderEOF
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header"
)
// Decode decodes m.Raw into m. // Decode decodes m.Raw into m.
func (m *Message) Decode() error { func (m *Message) Decode() error {
// decoding message header // decoding message header
buf := m.Raw buf := m.Raw
if len(buf) < messageHeaderSize {
return ErrUnexpectedHeaderEOF
}
var ( var (
t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes
size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes

View File

@@ -15,6 +15,8 @@ import (
"strings" "strings"
"testing" "testing"
"net"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -602,3 +604,34 @@ func BenchmarkNewTransactionID(b *testing.B) {
m.TransactionID = NewTransactionID() m.TransactionID = NewTransactionID()
} }
} }
func BenchmarkMessageFull(b *testing.B) {
b.ReportAllocs()
m := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5),
}
for i := 0; i < b.N; i++ {
m.Add(addr)
m.Add(s)
m.WriteAttributes()
m.Reset()
}
}
func BenchmarkMessageFullHardcore(b *testing.B) {
b.ReportAllocs()
m := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5),
}
t, v, _ := addr.Encode(nil, m)
for i := 0; i < b.N; i++ {
m.AddRaw(AttrSoftware, s.Raw)
m.AddRaw(t, v)
m.WriteAttributes()
m.Reset()
}
}

View File

@@ -23,7 +23,7 @@ func BenchmarkBasicProcess(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
res.Reset() res.Reset()
req.Reset() req.Reset()
if err := basicProcess(addr, m.Bytes(), req, res); err != nil { if err := basicProcess(addr, m.Raw, req, res); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }