all: use ReadFrom and WriteTo instead of Put and Get

This commit is contained in:
Aleksandr Razumov
2016-05-02 22:27:53 +03:00
parent 2d4a2df781
commit c735425ae8
5 changed files with 162 additions and 114 deletions

View File

@@ -7,6 +7,7 @@ import (
"testing"
//"encoding/binary"
"bytes"
"encoding/binary"
"strings"
)
@@ -20,7 +21,7 @@ func TestMessage_AddSoftware(t *testing.T) {
m2 := AcquireMessage()
defer ReleaseMessage(m2)
if err := m2.Get(m.buf.B); err != nil {
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
vRead := m.GetSoftware()
@@ -44,7 +45,7 @@ func TestMessage_AddSoftwareBytes(t *testing.T) {
m2 := AcquireMessage()
defer ReleaseMessage(m2)
if err := m2.Get(m.buf.B); err != nil {
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
vRead := m.GetSoftware()
@@ -145,7 +146,7 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
mRes := AcquireMessage()
defer ReleaseMessage(mRes)
binary.BigEndian.PutUint16(m.buf.B[20+4:20+4+2], 0x21)
if err = mRes.Get(m.buf.B); err != nil {
if _, err = mRes.ReadFrom(bytes.NewReader(m.buf.B)); err != nil {
t.Fatal(err)
}
_, _, err = m.GetXORMappedAddress()
@@ -169,7 +170,7 @@ func TestMessage_AddXORMappedAddress(t *testing.T) {
mRes := AcquireMessage()
defer ReleaseMessage(mRes)
if err = mRes.Get(m.buf.B); err != nil {
if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err)
}
ip, port, err := m.GetXORMappedAddress()

View File

@@ -50,9 +50,8 @@ func TestClientSend(t *testing.T) {
m.AddSoftware("cydev/stun alpha")
m.WriteHeader()
timeout := 100 * time.Millisecond
recvBuf := make([]byte, 1024)
for i := 0; i < 9; i++ {
_, err := conn.Write(m.buf.B)
_, err := m.WriteTo(conn)
if err != nil {
t.Fatal(err)
}
@@ -62,18 +61,18 @@ func TestClientSend(t *testing.T) {
if timeout < 1600*time.Millisecond {
timeout *= 2
}
n, err := conn.Read(recvBuf)
var (
ip net.IP
port int
)
if err == nil {
mRec := AcquireMessage()
if err = mRec.Get(recvBuf[:n]); err != nil {
if _, err = mRec.ReadFrom(conn); err != nil {
t.Error(err)
}
log.Println(mRec)
log.Println(mRec.Attributes)
log.Println("got message:", mRec)
log.Println("got attributes:", mRec.Attributes)
log.Println("got error:", err)
if mRec.TransactionID != m.TransactionID {
t.Error("TransactionID missmatch")
}

View File

@@ -26,7 +26,6 @@ func discover(c *cli.Context) error {
m.AddSoftware("cydev/stun alpha")
m.WriteHeader()
timeout := 100 * time.Millisecond
recvBuf := make([]byte, 1024)
for i := 0; i < 9; i++ {
_, err := m.WriteTo(conn)
if err != nil {
@@ -38,14 +37,13 @@ func discover(c *cli.Context) error {
if timeout < 1600*time.Millisecond {
timeout *= 2
}
n, err := conn.Read(recvBuf)
var (
ip net.IP
port int
)
if err == nil {
mRec := stun.AcquireMessage()
if err = mRec.Get(recvBuf[:n]); err != nil {
if _, err = mRec.ReadFrom(conn); err != nil {
return err
}
if mRec.TransactionID != m.TransactionID {

97
stun.go
View File

@@ -29,6 +29,8 @@ import (
"strconv"
"sync"
"bytes"
"github.com/cydev/buffer"
"github.com/pkg/errors"
)
@@ -269,70 +271,73 @@ func (m *Message) WriteHeader() {
binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf)-20))
}
// WriteTo implements WriterTo
func (m Message) WriteTo(w io.Writer) (int, error) {
return w.Write(m.buf.B)
// WriteTo implements WriterTo.
func (m Message) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.buf.B)
return int64(n), err
}
// Put encodes message into buf. If len(buf) is not enough, it panics.
func (m Message) Put(buf []byte) {
// encoding header
binary.BigEndian.PutUint16(buf[0:2], m.Type.Value())
binary.BigEndian.PutUint32(buf[4:8], magicCookie)
copy(buf[8:messageHeaderSize], m.TransactionID[:])
offset := messageHeaderSize
// encoding attributes
for _, a := range m.Attributes {
binary.BigEndian.PutUint16(buf[offset:offset+2], a.Type.Value())
offset += 2
binary.BigEndian.PutUint16(buf[offset:offset+2], a.Length)
offset += 2
copy(buf[offset:offset+len(a.Value)], a.Value[:])
offset += len(a.Value)
func AcquireFields(message Message) *Message {
m := AcquireMessage()
copy(m.TransactionID[:], message.TransactionID[:])
m.Type = message.Type
for _, a := range message.Attributes {
m.Add(a.Type, a.Value)
}
// writing length as size, in bytes, not including the 20-byte STUN header.
binary.BigEndian.PutUint16(buf[2:4], uint16(offset-20))
m.WriteHeader()
return m
}
// Get decodes message from byte slice and return error if any.
func (m *Message) reader() *bytes.Reader {
return bytes.NewReader(m.buf.B)
}
// ReadFrom implements ReaderFrom. Decodes message and return error if any.
//
// Can return ErrUnexpectedEOF, ErrInvalidMagicCookie, ErrInvalidMessageLength.
// Any error is unrecoverable, but message could be partially decoded.
//
// ErrUnexpectedEOF means that there were not enough bytes to read header or
// value and indicates possible data loss.
func (m *Message) Get(buf []byte) error {
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
m.mustWrite()
var (
read int64
n int
err error
)
if len(buf) < messageHeaderSize {
msg := fmt.Sprintln(len(buf), "<", messageHeaderSize, "message")
return errors.Wrap(io.ErrUnexpectedEOF, msg)
// reading packet into buffer
buf := make([]byte, MaxPacketSize)
if n, err = r.Read(buf); err != nil {
return int64(n), err
}
// writing to internal buffer
m.buf.B = m.buf.B[:0]
m.buf.Write(buf[:n])
read += int64(n)
// buf is now internal buffer
buf = m.buf.B[:n]
// decoding message header
m.Type.ReadValue(binary.BigEndian.Uint16(buf[0:2])) // first 2 bytes
tLength := binary.BigEndian.Uint16(buf[2:4]) // second 2 bytes
m.Length = uint32(tLength)
l := int(tLength)
cookie := binary.BigEndian.Uint32(buf[4:8])
copy(m.TransactionID[:], buf[8:messageHeaderSize])
if cookie != magicCookie {
return errors.Wrap(ErrInvalidMessageLength, "missmatch")
return read, ErrInvalidMagicCookie
}
offset := messageHeaderSize
mLength := int(m.Length)
if (mLength + offset) > len(buf) {
msg := fmt.Sprintln((mLength + offset), ">", len(buf))
return errors.Wrap(ErrInvalidMessageLength, msg)
}
for (mLength + messageHeaderSize - offset) > 0 {
buf = buf[messageHeaderSize:messageHeaderSize+l]
offset := 0
for offset < l {
b := buf[offset:]
// checking that we have enough bytes to read header
if len(b) < attributeHeaderSize {
msg := fmt.Sprintln(len(buf), "<", attributeHeaderSize, "header")
return errors.Wrap(io.ErrUnexpectedEOF, msg)
return int64(offset) + read, errors.Wrap(io.ErrUnexpectedEOF, msg)
}
a := Attribute{}
@@ -340,20 +345,20 @@ func (m *Message) Get(buf []byte) error {
t := binary.BigEndian.Uint16(b[0:2]) // first 2 bytes
a.Length = binary.BigEndian.Uint16(b[2:4]) // second 2 bytes
a.Type = AttrType(t)
l := int(a.Length)
aL := int(a.Length)
// reading value
b = b[attributeHeaderSize:] // slicing again to simplify value read
if len(b) < l { // checking size
msg := fmt.Sprintf("attr len(b) == %d < %d", len(b), l)
return errors.Wrap(io.ErrUnexpectedEOF, msg)
if len(b) < aL { // checking size
msg := fmt.Sprintf("attr len(b) == %d < %d", len(b), aL)
return int64(offset) + read, errors.Wrap(io.ErrUnexpectedEOF, msg)
}
a.Value = b[:l]
a.Value = b[:aL]
m.Attributes = append(m.Attributes, a)
offset += l + attributeHeaderSize
offset += aL + attributeHeaderSize
}
return nil
return int64(l + messageHeaderSize), nil
}
const (
@@ -361,6 +366,12 @@ const (
messageHeaderSize = 20
)
const (
// MaxPacketSize is maximum size of UDP packet that is processable in
// this package for STUN message.
MaxPacketSize = 2048
)
// AttrType is attribute type.
type AttrType uint16

View File

@@ -9,6 +9,8 @@ import (
"io"
"strings"
"bytes"
log "github.com/Sirupsen/logrus"
"github.com/pkg/errors"
)
@@ -41,8 +43,8 @@ func TestMessageBuffer(t *testing.T) {
m.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
mDecoded := &Message{}
if err := mDecoded.Get(m.buf.B); err != nil {
mDecoded := AcquireMessage()
if _, err := mDecoded.ReadFrom(bytes.NewReader(m.buf.B)); err != nil {
t.Error(err)
}
if !mDecoded.Equal(*m) {
@@ -123,51 +125,71 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
}
}
func TestMessage_PutGet(t *testing.T) {
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
messageAttributes := Attributes{
messageAttribute,
//func TestMessage_PutGet(t *testing.T) {
// mType := MessageType{Method: MethodBinding, Class: ClassRequest}
// messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
// messageAttributes := Attributes{
// messageAttribute,
// }
// m := Message{
// Type: mType,
// Length: 6,
// TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
// Attributes: messageAttributes,
// }
// buf := make([]byte, 128)
// m.Put(buf)
// mDecoded := Message{}
// if err := mDecoded.Get(buf); err != nil {
// t.Error(err)
// }
// if mDecoded.Type != m.Type {
// t.Error("incorrect type")
// }
// if mDecoded.Length != m.Length {
// t.Error("incorrect length")
// }
// if mDecoded.TransactionID != m.TransactionID {
// t.Error("incorrect transaction ID")
// }
// aDecoded := mDecoded.Attributes.Get(messageAttribute.Type)
// if !aDecoded.Equal(messageAttribute) {
// t.Error(aDecoded, "!=", messageAttribute)
// }
//}
func TestMessage_WriteTo(t *testing.T) {
m := AcquireMessage()
defer ReleaseMessage(m)
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
buf := new(bytes.Buffer)
if _, err := m.WriteTo(buf); err != nil {
t.Fatal(err)
}
m := Message{
Type: mType,
Length: 6,
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
Attributes: messageAttributes,
}
buf := make([]byte, 128)
m.Put(buf)
mDecoded := Message{}
if err := mDecoded.Get(buf); err != nil {
mDecoded := AcquireMessage()
if _, err := mDecoded.ReadFrom(buf); err != nil {
t.Error(err)
}
if mDecoded.Type != m.Type {
t.Error("incorrect type")
}
if mDecoded.Length != m.Length {
t.Error("incorrect length")
}
if mDecoded.TransactionID != m.TransactionID {
t.Error("incorrect transaction ID")
}
aDecoded := mDecoded.Attributes.Get(messageAttribute.Type)
if !aDecoded.Equal(messageAttribute) {
t.Error(aDecoded, "!=", messageAttribute)
if !mDecoded.Equal(*m) {
t.Error(mDecoded, "!", m)
}
}
func TestMessage_Cookie(t *testing.T) {
buf := make([]byte, 20)
mDecoded := Message{}
if err := mDecoded.Get(buf); err == nil {
mDecoded := AcquireMessage()
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
t.Error("should error")
}
}
func TestMessage_LengthLessHeaderSize(t *testing.T) {
buf := make([]byte, 8)
mDecoded := Message{}
if err := mDecoded.Get(buf); err == nil {
mDecoded := AcquireMessage()
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
t.Error("should error")
}
}
@@ -175,19 +197,17 @@ func TestMessage_LengthLessHeaderSize(t *testing.T) {
func TestMessage_BadLength(t *testing.T) {
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
messageAttributes := Attributes{
messageAttribute,
}
m := Message{
m := AcquireFields(Message{
Type: mType,
Length: 4,
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
Attributes: messageAttributes,
}
buf := make([]byte, 128)
m.Put(buf)
mDecoded := Message{}
if err := mDecoded.Get(buf[:20+3]); err == nil {
Attributes: []Attribute{messageAttribute},
})
b := new(bytes.Buffer)
m.WriteTo(b)
b.Truncate(20 + 3)
mDecoded := AcquireMessage()
if _, err := mDecoded.ReadFrom(b); err == nil {
t.Error("should error")
}
}
@@ -198,16 +218,17 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
messageAttributes := Attributes{
messageAttribute,
}
m := Message{
m := AcquireFields(Message{
Type: mType,
TransactionID: NewTransactionID(),
Attributes: messageAttributes,
}
buf := make([]byte, 128)
m.Put(buf)
})
b := new(bytes.Buffer)
m.WriteTo(b)
buf := b.Bytes()
mDecoded := AcquireMessage()
binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length
mDecoded := Message{}
err := mDecoded.Get(buf[:20+2])
_, err := mDecoded.ReadFrom(bytes.NewReader(buf[:20+2]))
if errors.Cause(err) != io.ErrUnexpectedEOF {
t.Error(err, "should be", io.ErrUnexpectedEOF)
}
@@ -215,20 +236,37 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
func TestMessage_AttrSizeLessThanLength(t *testing.T) {
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
messageAttribute := Attribute{Length: 4,
Value: []byte{1, 2, 3, 4}, Type: 0x1,
}
messageAttributes := Attributes{
messageAttribute,
}
m := Message{
m := AcquireFields(Message{
Type: mType,
TransactionID: NewTransactionID(),
Attributes: messageAttributes,
})
b := new(bytes.Buffer)
m.WriteTo(b)
buf := b.Bytes()
binary.BigEndian.PutUint16(buf[2:4], 5) // rewrite to bad length
mDecoded := AcquireMessage()
_, err := mDecoded.ReadFrom(bytes.NewReader(buf[:20+5]))
if errors.Cause(err) != io.ErrUnexpectedEOF {
t.Error(err, "should be", io.ErrUnexpectedEOF)
}
buf := make([]byte, 128)
m.Put(buf)
binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length
mDecoded := Message{}
err := mDecoded.Get(buf[:20+5])
}
type unexpectedEOFReader struct{}
func (_ unexpectedEOFReader) Read(b []byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
func TestMessage_ReadFromError(t *testing.T) {
mDecoded := AcquireMessage()
_, err := mDecoded.ReadFrom(unexpectedEOFReader{})
if errors.Cause(err) != io.ErrUnexpectedEOF {
t.Error(err, "should be", io.ErrUnexpectedEOF)
}
@@ -249,10 +287,11 @@ func BenchmarkMessage_Put(b *testing.B) {
Length: 0,
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
}
buf := make([]byte, 20)
buf := new(bytes.Buffer)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
m.Put(buf)
m.WriteTo(buf)
buf.Reset()
}
}