diff --git a/attributes_test.go b/attributes_test.go index 55552cd..f56712a 100644 --- a/attributes_test.go +++ b/attributes_test.go @@ -7,6 +7,7 @@ import ( "testing" //"encoding/binary" + "bytes" "encoding/binary" "strings" ) @@ -20,7 +21,7 @@ func TestMessage_AddSoftware(t *testing.T) { m2 := AcquireMessage() defer ReleaseMessage(m2) - if err := m2.Get(m.buf.B); err != nil { + if _, err := m2.ReadFrom(m.reader()); err != nil { t.Error(err) } vRead := m.GetSoftware() @@ -44,7 +45,7 @@ func TestMessage_AddSoftwareBytes(t *testing.T) { m2 := AcquireMessage() defer ReleaseMessage(m2) - if err := m2.Get(m.buf.B); err != nil { + if _, err := m2.ReadFrom(m.reader()); err != nil { t.Error(err) } vRead := m.GetSoftware() @@ -145,7 +146,7 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) { mRes := AcquireMessage() defer ReleaseMessage(mRes) binary.BigEndian.PutUint16(m.buf.B[20+4:20+4+2], 0x21) - if err = mRes.Get(m.buf.B); err != nil { + if _, err = mRes.ReadFrom(bytes.NewReader(m.buf.B)); err != nil { t.Fatal(err) } _, _, err = m.GetXORMappedAddress() @@ -169,7 +170,7 @@ func TestMessage_AddXORMappedAddress(t *testing.T) { mRes := AcquireMessage() defer ReleaseMessage(mRes) - if err = mRes.Get(m.buf.B); err != nil { + if _, err = mRes.ReadFrom(m.reader()); err != nil { t.Fatal(err) } ip, port, err := m.GetXORMappedAddress() diff --git a/client_test.go b/client_test.go index 27c0f42..e7eea0b 100644 --- a/client_test.go +++ b/client_test.go @@ -50,9 +50,8 @@ func TestClientSend(t *testing.T) { m.AddSoftware("cydev/stun alpha") m.WriteHeader() timeout := 100 * time.Millisecond - recvBuf := make([]byte, 1024) for i := 0; i < 9; i++ { - _, err := conn.Write(m.buf.B) + _, err := m.WriteTo(conn) if err != nil { t.Fatal(err) } @@ -62,18 +61,18 @@ func TestClientSend(t *testing.T) { if timeout < 1600*time.Millisecond { timeout *= 2 } - n, err := conn.Read(recvBuf) var ( ip net.IP port int ) if err == nil { mRec := AcquireMessage() - if err = mRec.Get(recvBuf[:n]); err != nil { + if _, err = mRec.ReadFrom(conn); err != nil { t.Error(err) } - log.Println(mRec) - log.Println(mRec.Attributes) + log.Println("got message:", mRec) + log.Println("got attributes:", mRec.Attributes) + log.Println("got error:", err) if mRec.TransactionID != m.TransactionID { t.Error("TransactionID missmatch") } diff --git a/stun-cli/cli.go b/stun-cli/cli.go index cd80239..4167af3 100644 --- a/stun-cli/cli.go +++ b/stun-cli/cli.go @@ -26,7 +26,6 @@ func discover(c *cli.Context) error { m.AddSoftware("cydev/stun alpha") m.WriteHeader() timeout := 100 * time.Millisecond - recvBuf := make([]byte, 1024) for i := 0; i < 9; i++ { _, err := m.WriteTo(conn) if err != nil { @@ -38,14 +37,13 @@ func discover(c *cli.Context) error { if timeout < 1600*time.Millisecond { timeout *= 2 } - n, err := conn.Read(recvBuf) var ( ip net.IP port int ) if err == nil { mRec := stun.AcquireMessage() - if err = mRec.Get(recvBuf[:n]); err != nil { + if _, err = mRec.ReadFrom(conn); err != nil { return err } if mRec.TransactionID != m.TransactionID { diff --git a/stun.go b/stun.go index 4d1d4cd..87c60aa 100644 --- a/stun.go +++ b/stun.go @@ -29,6 +29,8 @@ import ( "strconv" "sync" + "bytes" + "github.com/cydev/buffer" "github.com/pkg/errors" ) @@ -269,70 +271,73 @@ func (m *Message) WriteHeader() { binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf)-20)) } -// WriteTo implements WriterTo -func (m Message) WriteTo(w io.Writer) (int, error) { - return w.Write(m.buf.B) +// WriteTo implements WriterTo. +func (m Message) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(m.buf.B) + return int64(n), err } -// Put encodes message into buf. If len(buf) is not enough, it panics. -func (m Message) Put(buf []byte) { - // encoding header - binary.BigEndian.PutUint16(buf[0:2], m.Type.Value()) - binary.BigEndian.PutUint32(buf[4:8], magicCookie) - copy(buf[8:messageHeaderSize], m.TransactionID[:]) - offset := messageHeaderSize - // encoding attributes - for _, a := range m.Attributes { - binary.BigEndian.PutUint16(buf[offset:offset+2], a.Type.Value()) - offset += 2 - binary.BigEndian.PutUint16(buf[offset:offset+2], a.Length) - offset += 2 - copy(buf[offset:offset+len(a.Value)], a.Value[:]) - offset += len(a.Value) +func AcquireFields(message Message) *Message { + m := AcquireMessage() + copy(m.TransactionID[:], message.TransactionID[:]) + m.Type = message.Type + for _, a := range message.Attributes { + m.Add(a.Type, a.Value) } - // writing length as size, in bytes, not including the 20-byte STUN header. - binary.BigEndian.PutUint16(buf[2:4], uint16(offset-20)) + m.WriteHeader() + return m } -// Get decodes message from byte slice and return error if any. +func (m *Message) reader() *bytes.Reader { + return bytes.NewReader(m.buf.B) +} + +// ReadFrom implements ReaderFrom. Decodes message and return error if any. // // Can return ErrUnexpectedEOF, ErrInvalidMagicCookie, ErrInvalidMessageLength. // Any error is unrecoverable, but message could be partially decoded. // // ErrUnexpectedEOF means that there were not enough bytes to read header or -// value and indicates possible data loss. -func (m *Message) Get(buf []byte) error { +func (m *Message) ReadFrom(r io.Reader) (int64, error) { m.mustWrite() + var ( + read int64 + n int + err error + ) - if len(buf) < messageHeaderSize { - msg := fmt.Sprintln(len(buf), "<", messageHeaderSize, "message") - return errors.Wrap(io.ErrUnexpectedEOF, msg) + // reading packet into buffer + buf := make([]byte, MaxPacketSize) + if n, err = r.Read(buf); err != nil { + return int64(n), err } + // writing to internal buffer + m.buf.B = m.buf.B[:0] + m.buf.Write(buf[:n]) + read += int64(n) + // buf is now internal buffer + buf = m.buf.B[:n] // decoding message header m.Type.ReadValue(binary.BigEndian.Uint16(buf[0:2])) // first 2 bytes tLength := binary.BigEndian.Uint16(buf[2:4]) // second 2 bytes m.Length = uint32(tLength) + l := int(tLength) cookie := binary.BigEndian.Uint32(buf[4:8]) copy(m.TransactionID[:], buf[8:messageHeaderSize]) if cookie != magicCookie { - return errors.Wrap(ErrInvalidMessageLength, "missmatch") + return read, ErrInvalidMagicCookie } - offset := messageHeaderSize - mLength := int(m.Length) - if (mLength + offset) > len(buf) { - msg := fmt.Sprintln((mLength + offset), ">", len(buf)) - return errors.Wrap(ErrInvalidMessageLength, msg) - } - - for (mLength + messageHeaderSize - offset) > 0 { + buf = buf[messageHeaderSize:messageHeaderSize+l] + offset := 0 + for offset < l { b := buf[offset:] // checking that we have enough bytes to read header if len(b) < attributeHeaderSize { msg := fmt.Sprintln(len(buf), "<", attributeHeaderSize, "header") - return errors.Wrap(io.ErrUnexpectedEOF, msg) + return int64(offset) + read, errors.Wrap(io.ErrUnexpectedEOF, msg) } a := Attribute{} @@ -340,20 +345,20 @@ func (m *Message) Get(buf []byte) error { t := binary.BigEndian.Uint16(b[0:2]) // first 2 bytes a.Length = binary.BigEndian.Uint16(b[2:4]) // second 2 bytes a.Type = AttrType(t) - l := int(a.Length) + aL := int(a.Length) // reading value b = b[attributeHeaderSize:] // slicing again to simplify value read - if len(b) < l { // checking size - msg := fmt.Sprintf("attr len(b) == %d < %d", len(b), l) - return errors.Wrap(io.ErrUnexpectedEOF, msg) + if len(b) < aL { // checking size + msg := fmt.Sprintf("attr len(b) == %d < %d", len(b), aL) + return int64(offset) + read, errors.Wrap(io.ErrUnexpectedEOF, msg) } - a.Value = b[:l] + a.Value = b[:aL] m.Attributes = append(m.Attributes, a) - offset += l + attributeHeaderSize + offset += aL + attributeHeaderSize } - return nil + return int64(l + messageHeaderSize), nil } const ( @@ -361,6 +366,12 @@ const ( messageHeaderSize = 20 ) +const ( + // MaxPacketSize is maximum size of UDP packet that is processable in + // this package for STUN message. + MaxPacketSize = 2048 +) + // AttrType is attribute type. type AttrType uint16 diff --git a/stun_test.go b/stun_test.go index 009e30c..3e7e800 100644 --- a/stun_test.go +++ b/stun_test.go @@ -9,6 +9,8 @@ import ( "io" "strings" + "bytes" + log "github.com/Sirupsen/logrus" "github.com/pkg/errors" ) @@ -41,8 +43,8 @@ func TestMessageBuffer(t *testing.T) { m.TransactionID = NewTransactionID() m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) m.WriteHeader() - mDecoded := &Message{} - if err := mDecoded.Get(m.buf.B); err != nil { + mDecoded := AcquireMessage() + if _, err := mDecoded.ReadFrom(bytes.NewReader(m.buf.B)); err != nil { t.Error(err) } if !mDecoded.Equal(*m) { @@ -123,51 +125,71 @@ func TestMessageType_ReadWriteValue(t *testing.T) { } } -func TestMessage_PutGet(t *testing.T) { - mType := MessageType{Method: MethodBinding, Class: ClassRequest} - messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1} - messageAttributes := Attributes{ - messageAttribute, +//func TestMessage_PutGet(t *testing.T) { +// mType := MessageType{Method: MethodBinding, Class: ClassRequest} +// messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1} +// messageAttributes := Attributes{ +// messageAttribute, +// } +// m := Message{ +// Type: mType, +// Length: 6, +// TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, +// Attributes: messageAttributes, +// } +// buf := make([]byte, 128) +// m.Put(buf) +// mDecoded := Message{} +// if err := mDecoded.Get(buf); err != nil { +// t.Error(err) +// } +// if mDecoded.Type != m.Type { +// t.Error("incorrect type") +// } +// if mDecoded.Length != m.Length { +// t.Error("incorrect length") +// } +// if mDecoded.TransactionID != m.TransactionID { +// t.Error("incorrect transaction ID") +// } +// aDecoded := mDecoded.Attributes.Get(messageAttribute.Type) +// if !aDecoded.Equal(messageAttribute) { +// t.Error(aDecoded, "!=", messageAttribute) +// } +//} + +func TestMessage_WriteTo(t *testing.T) { + m := AcquireMessage() + defer ReleaseMessage(m) + m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} + m.TransactionID = NewTransactionID() + m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) + m.WriteHeader() + buf := new(bytes.Buffer) + if _, err := m.WriteTo(buf); err != nil { + t.Fatal(err) } - m := Message{ - Type: mType, - Length: 6, - TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Attributes: messageAttributes, - } - buf := make([]byte, 128) - m.Put(buf) - mDecoded := Message{} - if err := mDecoded.Get(buf); err != nil { + mDecoded := AcquireMessage() + if _, err := mDecoded.ReadFrom(buf); err != nil { t.Error(err) } - if mDecoded.Type != m.Type { - t.Error("incorrect type") - } - if mDecoded.Length != m.Length { - t.Error("incorrect length") - } - if mDecoded.TransactionID != m.TransactionID { - t.Error("incorrect transaction ID") - } - aDecoded := mDecoded.Attributes.Get(messageAttribute.Type) - if !aDecoded.Equal(messageAttribute) { - t.Error(aDecoded, "!=", messageAttribute) + if !mDecoded.Equal(*m) { + t.Error(mDecoded, "!", m) } } func TestMessage_Cookie(t *testing.T) { buf := make([]byte, 20) - mDecoded := Message{} - if err := mDecoded.Get(buf); err == nil { + mDecoded := AcquireMessage() + if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil { t.Error("should error") } } func TestMessage_LengthLessHeaderSize(t *testing.T) { buf := make([]byte, 8) - mDecoded := Message{} - if err := mDecoded.Get(buf); err == nil { + mDecoded := AcquireMessage() + if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil { t.Error("should error") } } @@ -175,19 +197,17 @@ func TestMessage_LengthLessHeaderSize(t *testing.T) { func TestMessage_BadLength(t *testing.T) { mType := MessageType{Method: MethodBinding, Class: ClassRequest} messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1} - messageAttributes := Attributes{ - messageAttribute, - } - m := Message{ + m := AcquireFields(Message{ Type: mType, Length: 4, TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Attributes: messageAttributes, - } - buf := make([]byte, 128) - m.Put(buf) - mDecoded := Message{} - if err := mDecoded.Get(buf[:20+3]); err == nil { + Attributes: []Attribute{messageAttribute}, + }) + b := new(bytes.Buffer) + m.WriteTo(b) + b.Truncate(20 + 3) + mDecoded := AcquireMessage() + if _, err := mDecoded.ReadFrom(b); err == nil { t.Error("should error") } } @@ -198,16 +218,17 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) { messageAttributes := Attributes{ messageAttribute, } - m := Message{ + m := AcquireFields(Message{ Type: mType, TransactionID: NewTransactionID(), Attributes: messageAttributes, - } - buf := make([]byte, 128) - m.Put(buf) + }) + b := new(bytes.Buffer) + m.WriteTo(b) + buf := b.Bytes() + mDecoded := AcquireMessage() binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length - mDecoded := Message{} - err := mDecoded.Get(buf[:20+2]) + _, err := mDecoded.ReadFrom(bytes.NewReader(buf[:20+2])) if errors.Cause(err) != io.ErrUnexpectedEOF { t.Error(err, "should be", io.ErrUnexpectedEOF) } @@ -215,20 +236,37 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) { func TestMessage_AttrSizeLessThanLength(t *testing.T) { mType := MessageType{Method: MethodBinding, Class: ClassRequest} - messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1} + messageAttribute := Attribute{Length: 4, + Value: []byte{1, 2, 3, 4}, Type: 0x1, + } messageAttributes := Attributes{ messageAttribute, } - m := Message{ + m := AcquireFields(Message{ Type: mType, TransactionID: NewTransactionID(), Attributes: messageAttributes, + }) + b := new(bytes.Buffer) + m.WriteTo(b) + buf := b.Bytes() + binary.BigEndian.PutUint16(buf[2:4], 5) // rewrite to bad length + mDecoded := AcquireMessage() + _, err := mDecoded.ReadFrom(bytes.NewReader(buf[:20+5])) + if errors.Cause(err) != io.ErrUnexpectedEOF { + t.Error(err, "should be", io.ErrUnexpectedEOF) } - buf := make([]byte, 128) - m.Put(buf) - binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length - mDecoded := Message{} - err := mDecoded.Get(buf[:20+5]) +} + +type unexpectedEOFReader struct{} + +func (_ unexpectedEOFReader) Read(b []byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + +func TestMessage_ReadFromError(t *testing.T) { + mDecoded := AcquireMessage() + _, err := mDecoded.ReadFrom(unexpectedEOFReader{}) if errors.Cause(err) != io.ErrUnexpectedEOF { t.Error(err, "should be", io.ErrUnexpectedEOF) } @@ -249,10 +287,11 @@ func BenchmarkMessage_Put(b *testing.B) { Length: 0, TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, } - buf := make([]byte, 20) + buf := new(bytes.Buffer) b.ReportAllocs() for i := 0; i < b.N; i++ { - m.Put(buf) + m.WriteTo(buf) + buf.Reset() } }