all: use ReadFrom and WriteTo instead of Put and Get

This commit is contained in:
Aleksandr Razumov
2016-05-02 22:27:53 +03:00
parent 2d4a2df781
commit c735425ae8
5 changed files with 162 additions and 114 deletions

View File

@@ -7,6 +7,7 @@ import (
"testing" "testing"
//"encoding/binary" //"encoding/binary"
"bytes"
"encoding/binary" "encoding/binary"
"strings" "strings"
) )
@@ -20,7 +21,7 @@ func TestMessage_AddSoftware(t *testing.T) {
m2 := AcquireMessage() m2 := AcquireMessage()
defer ReleaseMessage(m2) defer ReleaseMessage(m2)
if err := m2.Get(m.buf.B); err != nil { if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err) t.Error(err)
} }
vRead := m.GetSoftware() vRead := m.GetSoftware()
@@ -44,7 +45,7 @@ func TestMessage_AddSoftwareBytes(t *testing.T) {
m2 := AcquireMessage() m2 := AcquireMessage()
defer ReleaseMessage(m2) defer ReleaseMessage(m2)
if err := m2.Get(m.buf.B); err != nil { if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err) t.Error(err)
} }
vRead := m.GetSoftware() vRead := m.GetSoftware()
@@ -145,7 +146,7 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
mRes := AcquireMessage() mRes := AcquireMessage()
defer ReleaseMessage(mRes) defer ReleaseMessage(mRes)
binary.BigEndian.PutUint16(m.buf.B[20+4:20+4+2], 0x21) binary.BigEndian.PutUint16(m.buf.B[20+4:20+4+2], 0x21)
if err = mRes.Get(m.buf.B); err != nil { if _, err = mRes.ReadFrom(bytes.NewReader(m.buf.B)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, _, err = m.GetXORMappedAddress() _, _, err = m.GetXORMappedAddress()
@@ -169,7 +170,7 @@ func TestMessage_AddXORMappedAddress(t *testing.T) {
mRes := AcquireMessage() mRes := AcquireMessage()
defer ReleaseMessage(mRes) defer ReleaseMessage(mRes)
if err = mRes.Get(m.buf.B); err != nil { if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err) t.Fatal(err)
} }
ip, port, err := m.GetXORMappedAddress() ip, port, err := m.GetXORMappedAddress()

View File

@@ -50,9 +50,8 @@ func TestClientSend(t *testing.T) {
m.AddSoftware("cydev/stun alpha") m.AddSoftware("cydev/stun alpha")
m.WriteHeader() m.WriteHeader()
timeout := 100 * time.Millisecond timeout := 100 * time.Millisecond
recvBuf := make([]byte, 1024)
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
_, err := conn.Write(m.buf.B) _, err := m.WriteTo(conn)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -62,18 +61,18 @@ func TestClientSend(t *testing.T) {
if timeout < 1600*time.Millisecond { if timeout < 1600*time.Millisecond {
timeout *= 2 timeout *= 2
} }
n, err := conn.Read(recvBuf)
var ( var (
ip net.IP ip net.IP
port int port int
) )
if err == nil { if err == nil {
mRec := AcquireMessage() mRec := AcquireMessage()
if err = mRec.Get(recvBuf[:n]); err != nil { if _, err = mRec.ReadFrom(conn); err != nil {
t.Error(err) t.Error(err)
} }
log.Println(mRec) log.Println("got message:", mRec)
log.Println(mRec.Attributes) log.Println("got attributes:", mRec.Attributes)
log.Println("got error:", err)
if mRec.TransactionID != m.TransactionID { if mRec.TransactionID != m.TransactionID {
t.Error("TransactionID missmatch") t.Error("TransactionID missmatch")
} }

View File

@@ -26,7 +26,6 @@ func discover(c *cli.Context) error {
m.AddSoftware("cydev/stun alpha") m.AddSoftware("cydev/stun alpha")
m.WriteHeader() m.WriteHeader()
timeout := 100 * time.Millisecond timeout := 100 * time.Millisecond
recvBuf := make([]byte, 1024)
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
_, err := m.WriteTo(conn) _, err := m.WriteTo(conn)
if err != nil { if err != nil {
@@ -38,14 +37,13 @@ func discover(c *cli.Context) error {
if timeout < 1600*time.Millisecond { if timeout < 1600*time.Millisecond {
timeout *= 2 timeout *= 2
} }
n, err := conn.Read(recvBuf)
var ( var (
ip net.IP ip net.IP
port int port int
) )
if err == nil { if err == nil {
mRec := stun.AcquireMessage() mRec := stun.AcquireMessage()
if err = mRec.Get(recvBuf[:n]); err != nil { if _, err = mRec.ReadFrom(conn); err != nil {
return err return err
} }
if mRec.TransactionID != m.TransactionID { if mRec.TransactionID != m.TransactionID {

97
stun.go
View File

@@ -29,6 +29,8 @@ import (
"strconv" "strconv"
"sync" "sync"
"bytes"
"github.com/cydev/buffer" "github.com/cydev/buffer"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -269,70 +271,73 @@ func (m *Message) WriteHeader() {
binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf)-20)) binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf)-20))
} }
// WriteTo implements WriterTo // WriteTo implements WriterTo.
func (m Message) WriteTo(w io.Writer) (int, error) { func (m Message) WriteTo(w io.Writer) (int64, error) {
return w.Write(m.buf.B) n, err := w.Write(m.buf.B)
return int64(n), err
} }
// Put encodes message into buf. If len(buf) is not enough, it panics. func AcquireFields(message Message) *Message {
func (m Message) Put(buf []byte) { m := AcquireMessage()
// encoding header copy(m.TransactionID[:], message.TransactionID[:])
binary.BigEndian.PutUint16(buf[0:2], m.Type.Value()) m.Type = message.Type
binary.BigEndian.PutUint32(buf[4:8], magicCookie) for _, a := range message.Attributes {
copy(buf[8:messageHeaderSize], m.TransactionID[:]) m.Add(a.Type, a.Value)
offset := messageHeaderSize
// encoding attributes
for _, a := range m.Attributes {
binary.BigEndian.PutUint16(buf[offset:offset+2], a.Type.Value())
offset += 2
binary.BigEndian.PutUint16(buf[offset:offset+2], a.Length)
offset += 2
copy(buf[offset:offset+len(a.Value)], a.Value[:])
offset += len(a.Value)
} }
// writing length as size, in bytes, not including the 20-byte STUN header. m.WriteHeader()
binary.BigEndian.PutUint16(buf[2:4], uint16(offset-20)) return m
} }
// Get decodes message from byte slice and return error if any. func (m *Message) reader() *bytes.Reader {
return bytes.NewReader(m.buf.B)
}
// ReadFrom implements ReaderFrom. Decodes message and return error if any.
// //
// Can return ErrUnexpectedEOF, ErrInvalidMagicCookie, ErrInvalidMessageLength. // Can return ErrUnexpectedEOF, ErrInvalidMagicCookie, ErrInvalidMessageLength.
// Any error is unrecoverable, but message could be partially decoded. // Any error is unrecoverable, but message could be partially decoded.
// //
// ErrUnexpectedEOF means that there were not enough bytes to read header or // ErrUnexpectedEOF means that there were not enough bytes to read header or
// value and indicates possible data loss. func (m *Message) ReadFrom(r io.Reader) (int64, error) {
func (m *Message) Get(buf []byte) error {
m.mustWrite() m.mustWrite()
var (
read int64
n int
err error
)
if len(buf) < messageHeaderSize { // reading packet into buffer
msg := fmt.Sprintln(len(buf), "<", messageHeaderSize, "message") buf := make([]byte, MaxPacketSize)
return errors.Wrap(io.ErrUnexpectedEOF, msg) if n, err = r.Read(buf); err != nil {
return int64(n), err
} }
// writing to internal buffer
m.buf.B = m.buf.B[:0]
m.buf.Write(buf[:n])
read += int64(n)
// buf is now internal buffer
buf = m.buf.B[:n]
// decoding message header // decoding message header
m.Type.ReadValue(binary.BigEndian.Uint16(buf[0:2])) // first 2 bytes m.Type.ReadValue(binary.BigEndian.Uint16(buf[0:2])) // first 2 bytes
tLength := binary.BigEndian.Uint16(buf[2:4]) // second 2 bytes tLength := binary.BigEndian.Uint16(buf[2:4]) // second 2 bytes
m.Length = uint32(tLength) m.Length = uint32(tLength)
l := int(tLength)
cookie := binary.BigEndian.Uint32(buf[4:8]) cookie := binary.BigEndian.Uint32(buf[4:8])
copy(m.TransactionID[:], buf[8:messageHeaderSize]) copy(m.TransactionID[:], buf[8:messageHeaderSize])
if cookie != magicCookie { if cookie != magicCookie {
return errors.Wrap(ErrInvalidMessageLength, "missmatch") return read, ErrInvalidMagicCookie
} }
offset := messageHeaderSize buf = buf[messageHeaderSize:messageHeaderSize+l]
mLength := int(m.Length) offset := 0
if (mLength + offset) > len(buf) { for offset < l {
msg := fmt.Sprintln((mLength + offset), ">", len(buf))
return errors.Wrap(ErrInvalidMessageLength, msg)
}
for (mLength + messageHeaderSize - offset) > 0 {
b := buf[offset:] b := buf[offset:]
// checking that we have enough bytes to read header // checking that we have enough bytes to read header
if len(b) < attributeHeaderSize { if len(b) < attributeHeaderSize {
msg := fmt.Sprintln(len(buf), "<", attributeHeaderSize, "header") msg := fmt.Sprintln(len(buf), "<", attributeHeaderSize, "header")
return errors.Wrap(io.ErrUnexpectedEOF, msg) return int64(offset) + read, errors.Wrap(io.ErrUnexpectedEOF, msg)
} }
a := Attribute{} a := Attribute{}
@@ -340,20 +345,20 @@ func (m *Message) Get(buf []byte) error {
t := binary.BigEndian.Uint16(b[0:2]) // first 2 bytes t := binary.BigEndian.Uint16(b[0:2]) // first 2 bytes
a.Length = binary.BigEndian.Uint16(b[2:4]) // second 2 bytes a.Length = binary.BigEndian.Uint16(b[2:4]) // second 2 bytes
a.Type = AttrType(t) a.Type = AttrType(t)
l := int(a.Length) aL := int(a.Length)
// reading value // reading value
b = b[attributeHeaderSize:] // slicing again to simplify value read b = b[attributeHeaderSize:] // slicing again to simplify value read
if len(b) < l { // checking size if len(b) < aL { // checking size
msg := fmt.Sprintf("attr len(b) == %d < %d", len(b), l) msg := fmt.Sprintf("attr len(b) == %d < %d", len(b), aL)
return errors.Wrap(io.ErrUnexpectedEOF, msg) return int64(offset) + read, errors.Wrap(io.ErrUnexpectedEOF, msg)
} }
a.Value = b[:l] a.Value = b[:aL]
m.Attributes = append(m.Attributes, a) m.Attributes = append(m.Attributes, a)
offset += l + attributeHeaderSize offset += aL + attributeHeaderSize
} }
return nil return int64(l + messageHeaderSize), nil
} }
const ( const (
@@ -361,6 +366,12 @@ const (
messageHeaderSize = 20 messageHeaderSize = 20
) )
const (
// MaxPacketSize is maximum size of UDP packet that is processable in
// this package for STUN message.
MaxPacketSize = 2048
)
// AttrType is attribute type. // AttrType is attribute type.
type AttrType uint16 type AttrType uint16

View File

@@ -9,6 +9,8 @@ import (
"io" "io"
"strings" "strings"
"bytes"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -41,8 +43,8 @@ func TestMessageBuffer(t *testing.T) {
m.TransactionID = NewTransactionID() m.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader() m.WriteHeader()
mDecoded := &Message{} mDecoded := AcquireMessage()
if err := mDecoded.Get(m.buf.B); err != nil { if _, err := mDecoded.ReadFrom(bytes.NewReader(m.buf.B)); err != nil {
t.Error(err) t.Error(err)
} }
if !mDecoded.Equal(*m) { if !mDecoded.Equal(*m) {
@@ -123,51 +125,71 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
} }
} }
func TestMessage_PutGet(t *testing.T) { //func TestMessage_PutGet(t *testing.T) {
mType := MessageType{Method: MethodBinding, Class: ClassRequest} // mType := MessageType{Method: MethodBinding, Class: ClassRequest}
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1} // messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
messageAttributes := Attributes{ // messageAttributes := Attributes{
messageAttribute, // messageAttribute,
// }
// m := Message{
// Type: mType,
// Length: 6,
// TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
// Attributes: messageAttributes,
// }
// buf := make([]byte, 128)
// m.Put(buf)
// mDecoded := Message{}
// if err := mDecoded.Get(buf); err != nil {
// t.Error(err)
// }
// if mDecoded.Type != m.Type {
// t.Error("incorrect type")
// }
// if mDecoded.Length != m.Length {
// t.Error("incorrect length")
// }
// if mDecoded.TransactionID != m.TransactionID {
// t.Error("incorrect transaction ID")
// }
// aDecoded := mDecoded.Attributes.Get(messageAttribute.Type)
// if !aDecoded.Equal(messageAttribute) {
// t.Error(aDecoded, "!=", messageAttribute)
// }
//}
func TestMessage_WriteTo(t *testing.T) {
m := AcquireMessage()
defer ReleaseMessage(m)
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
buf := new(bytes.Buffer)
if _, err := m.WriteTo(buf); err != nil {
t.Fatal(err)
} }
m := Message{ mDecoded := AcquireMessage()
Type: mType, if _, err := mDecoded.ReadFrom(buf); err != nil {
Length: 6,
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
Attributes: messageAttributes,
}
buf := make([]byte, 128)
m.Put(buf)
mDecoded := Message{}
if err := mDecoded.Get(buf); err != nil {
t.Error(err) t.Error(err)
} }
if mDecoded.Type != m.Type { if !mDecoded.Equal(*m) {
t.Error("incorrect type") t.Error(mDecoded, "!", m)
}
if mDecoded.Length != m.Length {
t.Error("incorrect length")
}
if mDecoded.TransactionID != m.TransactionID {
t.Error("incorrect transaction ID")
}
aDecoded := mDecoded.Attributes.Get(messageAttribute.Type)
if !aDecoded.Equal(messageAttribute) {
t.Error(aDecoded, "!=", messageAttribute)
} }
} }
func TestMessage_Cookie(t *testing.T) { func TestMessage_Cookie(t *testing.T) {
buf := make([]byte, 20) buf := make([]byte, 20)
mDecoded := Message{} mDecoded := AcquireMessage()
if err := mDecoded.Get(buf); err == nil { if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
t.Error("should error") t.Error("should error")
} }
} }
func TestMessage_LengthLessHeaderSize(t *testing.T) { func TestMessage_LengthLessHeaderSize(t *testing.T) {
buf := make([]byte, 8) buf := make([]byte, 8)
mDecoded := Message{} mDecoded := AcquireMessage()
if err := mDecoded.Get(buf); err == nil { if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
t.Error("should error") t.Error("should error")
} }
} }
@@ -175,19 +197,17 @@ func TestMessage_LengthLessHeaderSize(t *testing.T) {
func TestMessage_BadLength(t *testing.T) { func TestMessage_BadLength(t *testing.T) {
mType := MessageType{Method: MethodBinding, Class: ClassRequest} mType := MessageType{Method: MethodBinding, Class: ClassRequest}
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1} messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
messageAttributes := Attributes{ m := AcquireFields(Message{
messageAttribute,
}
m := Message{
Type: mType, Type: mType,
Length: 4, Length: 4,
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
Attributes: messageAttributes, Attributes: []Attribute{messageAttribute},
} })
buf := make([]byte, 128) b := new(bytes.Buffer)
m.Put(buf) m.WriteTo(b)
mDecoded := Message{} b.Truncate(20 + 3)
if err := mDecoded.Get(buf[:20+3]); err == nil { mDecoded := AcquireMessage()
if _, err := mDecoded.ReadFrom(b); err == nil {
t.Error("should error") t.Error("should error")
} }
} }
@@ -198,16 +218,17 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
messageAttributes := Attributes{ messageAttributes := Attributes{
messageAttribute, messageAttribute,
} }
m := Message{ m := AcquireFields(Message{
Type: mType, Type: mType,
TransactionID: NewTransactionID(), TransactionID: NewTransactionID(),
Attributes: messageAttributes, Attributes: messageAttributes,
} })
buf := make([]byte, 128) b := new(bytes.Buffer)
m.Put(buf) m.WriteTo(b)
buf := b.Bytes()
mDecoded := AcquireMessage()
binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length
mDecoded := Message{} _, err := mDecoded.ReadFrom(bytes.NewReader(buf[:20+2]))
err := mDecoded.Get(buf[:20+2])
if errors.Cause(err) != io.ErrUnexpectedEOF { if errors.Cause(err) != io.ErrUnexpectedEOF {
t.Error(err, "should be", io.ErrUnexpectedEOF) t.Error(err, "should be", io.ErrUnexpectedEOF)
} }
@@ -215,20 +236,37 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
func TestMessage_AttrSizeLessThanLength(t *testing.T) { func TestMessage_AttrSizeLessThanLength(t *testing.T) {
mType := MessageType{Method: MethodBinding, Class: ClassRequest} mType := MessageType{Method: MethodBinding, Class: ClassRequest}
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1} messageAttribute := Attribute{Length: 4,
Value: []byte{1, 2, 3, 4}, Type: 0x1,
}
messageAttributes := Attributes{ messageAttributes := Attributes{
messageAttribute, messageAttribute,
} }
m := Message{ m := AcquireFields(Message{
Type: mType, Type: mType,
TransactionID: NewTransactionID(), TransactionID: NewTransactionID(),
Attributes: messageAttributes, Attributes: messageAttributes,
})
b := new(bytes.Buffer)
m.WriteTo(b)
buf := b.Bytes()
binary.BigEndian.PutUint16(buf[2:4], 5) // rewrite to bad length
mDecoded := AcquireMessage()
_, err := mDecoded.ReadFrom(bytes.NewReader(buf[:20+5]))
if errors.Cause(err) != io.ErrUnexpectedEOF {
t.Error(err, "should be", io.ErrUnexpectedEOF)
} }
buf := make([]byte, 128) }
m.Put(buf)
binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length type unexpectedEOFReader struct{}
mDecoded := Message{}
err := mDecoded.Get(buf[:20+5]) func (_ unexpectedEOFReader) Read(b []byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
func TestMessage_ReadFromError(t *testing.T) {
mDecoded := AcquireMessage()
_, err := mDecoded.ReadFrom(unexpectedEOFReader{})
if errors.Cause(err) != io.ErrUnexpectedEOF { if errors.Cause(err) != io.ErrUnexpectedEOF {
t.Error(err, "should be", io.ErrUnexpectedEOF) t.Error(err, "should be", io.ErrUnexpectedEOF)
} }
@@ -249,10 +287,11 @@ func BenchmarkMessage_Put(b *testing.B) {
Length: 0, Length: 0,
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
} }
buf := make([]byte, 20) buf := new(bytes.Buffer)
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.Put(buf) m.WriteTo(buf)
buf.Reset()
} }
} }