mirror of
https://github.com/pion/stun.git
synced 2025-10-07 08:40:53 +08:00
all: use ReadFrom and WriteTo instead of Put and Get
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
//"encoding/binary"
|
//"encoding/binary"
|
||||||
|
|
||||||
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -20,7 +21,7 @@ func TestMessage_AddSoftware(t *testing.T) {
|
|||||||
|
|
||||||
m2 := AcquireMessage()
|
m2 := AcquireMessage()
|
||||||
defer ReleaseMessage(m2)
|
defer ReleaseMessage(m2)
|
||||||
if err := m2.Get(m.buf.B); err != nil {
|
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
vRead := m.GetSoftware()
|
vRead := m.GetSoftware()
|
||||||
@@ -44,7 +45,7 @@ func TestMessage_AddSoftwareBytes(t *testing.T) {
|
|||||||
|
|
||||||
m2 := AcquireMessage()
|
m2 := AcquireMessage()
|
||||||
defer ReleaseMessage(m2)
|
defer ReleaseMessage(m2)
|
||||||
if err := m2.Get(m.buf.B); err != nil {
|
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
vRead := m.GetSoftware()
|
vRead := m.GetSoftware()
|
||||||
@@ -145,7 +146,7 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
|
|||||||
mRes := AcquireMessage()
|
mRes := AcquireMessage()
|
||||||
defer ReleaseMessage(mRes)
|
defer ReleaseMessage(mRes)
|
||||||
binary.BigEndian.PutUint16(m.buf.B[20+4:20+4+2], 0x21)
|
binary.BigEndian.PutUint16(m.buf.B[20+4:20+4+2], 0x21)
|
||||||
if err = mRes.Get(m.buf.B); err != nil {
|
if _, err = mRes.ReadFrom(bytes.NewReader(m.buf.B)); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
_, _, err = m.GetXORMappedAddress()
|
_, _, err = m.GetXORMappedAddress()
|
||||||
@@ -169,7 +170,7 @@ func TestMessage_AddXORMappedAddress(t *testing.T) {
|
|||||||
|
|
||||||
mRes := AcquireMessage()
|
mRes := AcquireMessage()
|
||||||
defer ReleaseMessage(mRes)
|
defer ReleaseMessage(mRes)
|
||||||
if err = mRes.Get(m.buf.B); err != nil {
|
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
ip, port, err := m.GetXORMappedAddress()
|
ip, port, err := m.GetXORMappedAddress()
|
||||||
|
@@ -50,9 +50,8 @@ func TestClientSend(t *testing.T) {
|
|||||||
m.AddSoftware("cydev/stun alpha")
|
m.AddSoftware("cydev/stun alpha")
|
||||||
m.WriteHeader()
|
m.WriteHeader()
|
||||||
timeout := 100 * time.Millisecond
|
timeout := 100 * time.Millisecond
|
||||||
recvBuf := make([]byte, 1024)
|
|
||||||
for i := 0; i < 9; i++ {
|
for i := 0; i < 9; i++ {
|
||||||
_, err := conn.Write(m.buf.B)
|
_, err := m.WriteTo(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -62,18 +61,18 @@ func TestClientSend(t *testing.T) {
|
|||||||
if timeout < 1600*time.Millisecond {
|
if timeout < 1600*time.Millisecond {
|
||||||
timeout *= 2
|
timeout *= 2
|
||||||
}
|
}
|
||||||
n, err := conn.Read(recvBuf)
|
|
||||||
var (
|
var (
|
||||||
ip net.IP
|
ip net.IP
|
||||||
port int
|
port int
|
||||||
)
|
)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
mRec := AcquireMessage()
|
mRec := AcquireMessage()
|
||||||
if err = mRec.Get(recvBuf[:n]); err != nil {
|
if _, err = mRec.ReadFrom(conn); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
log.Println(mRec)
|
log.Println("got message:", mRec)
|
||||||
log.Println(mRec.Attributes)
|
log.Println("got attributes:", mRec.Attributes)
|
||||||
|
log.Println("got error:", err)
|
||||||
if mRec.TransactionID != m.TransactionID {
|
if mRec.TransactionID != m.TransactionID {
|
||||||
t.Error("TransactionID missmatch")
|
t.Error("TransactionID missmatch")
|
||||||
}
|
}
|
||||||
|
@@ -26,7 +26,6 @@ func discover(c *cli.Context) error {
|
|||||||
m.AddSoftware("cydev/stun alpha")
|
m.AddSoftware("cydev/stun alpha")
|
||||||
m.WriteHeader()
|
m.WriteHeader()
|
||||||
timeout := 100 * time.Millisecond
|
timeout := 100 * time.Millisecond
|
||||||
recvBuf := make([]byte, 1024)
|
|
||||||
for i := 0; i < 9; i++ {
|
for i := 0; i < 9; i++ {
|
||||||
_, err := m.WriteTo(conn)
|
_, err := m.WriteTo(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -38,14 +37,13 @@ func discover(c *cli.Context) error {
|
|||||||
if timeout < 1600*time.Millisecond {
|
if timeout < 1600*time.Millisecond {
|
||||||
timeout *= 2
|
timeout *= 2
|
||||||
}
|
}
|
||||||
n, err := conn.Read(recvBuf)
|
|
||||||
var (
|
var (
|
||||||
ip net.IP
|
ip net.IP
|
||||||
port int
|
port int
|
||||||
)
|
)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
mRec := stun.AcquireMessage()
|
mRec := stun.AcquireMessage()
|
||||||
if err = mRec.Get(recvBuf[:n]); err != nil {
|
if _, err = mRec.ReadFrom(conn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if mRec.TransactionID != m.TransactionID {
|
if mRec.TransactionID != m.TransactionID {
|
||||||
|
97
stun.go
97
stun.go
@@ -29,6 +29,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"bytes"
|
||||||
|
|
||||||
"github.com/cydev/buffer"
|
"github.com/cydev/buffer"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
@@ -269,70 +271,73 @@ func (m *Message) WriteHeader() {
|
|||||||
binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf)-20))
|
binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf)-20))
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteTo implements WriterTo
|
// WriteTo implements WriterTo.
|
||||||
func (m Message) WriteTo(w io.Writer) (int, error) {
|
func (m Message) WriteTo(w io.Writer) (int64, error) {
|
||||||
return w.Write(m.buf.B)
|
n, err := w.Write(m.buf.B)
|
||||||
|
return int64(n), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put encodes message into buf. If len(buf) is not enough, it panics.
|
func AcquireFields(message Message) *Message {
|
||||||
func (m Message) Put(buf []byte) {
|
m := AcquireMessage()
|
||||||
// encoding header
|
copy(m.TransactionID[:], message.TransactionID[:])
|
||||||
binary.BigEndian.PutUint16(buf[0:2], m.Type.Value())
|
m.Type = message.Type
|
||||||
binary.BigEndian.PutUint32(buf[4:8], magicCookie)
|
for _, a := range message.Attributes {
|
||||||
copy(buf[8:messageHeaderSize], m.TransactionID[:])
|
m.Add(a.Type, a.Value)
|
||||||
offset := messageHeaderSize
|
|
||||||
// encoding attributes
|
|
||||||
for _, a := range m.Attributes {
|
|
||||||
binary.BigEndian.PutUint16(buf[offset:offset+2], a.Type.Value())
|
|
||||||
offset += 2
|
|
||||||
binary.BigEndian.PutUint16(buf[offset:offset+2], a.Length)
|
|
||||||
offset += 2
|
|
||||||
copy(buf[offset:offset+len(a.Value)], a.Value[:])
|
|
||||||
offset += len(a.Value)
|
|
||||||
}
|
}
|
||||||
// writing length as size, in bytes, not including the 20-byte STUN header.
|
m.WriteHeader()
|
||||||
binary.BigEndian.PutUint16(buf[2:4], uint16(offset-20))
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get decodes message from byte slice and return error if any.
|
func (m *Message) reader() *bytes.Reader {
|
||||||
|
return bytes.NewReader(m.buf.B)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFrom implements ReaderFrom. Decodes message and return error if any.
|
||||||
//
|
//
|
||||||
// Can return ErrUnexpectedEOF, ErrInvalidMagicCookie, ErrInvalidMessageLength.
|
// Can return ErrUnexpectedEOF, ErrInvalidMagicCookie, ErrInvalidMessageLength.
|
||||||
// Any error is unrecoverable, but message could be partially decoded.
|
// Any error is unrecoverable, but message could be partially decoded.
|
||||||
//
|
//
|
||||||
// ErrUnexpectedEOF means that there were not enough bytes to read header or
|
// ErrUnexpectedEOF means that there were not enough bytes to read header or
|
||||||
// value and indicates possible data loss.
|
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
|
||||||
func (m *Message) Get(buf []byte) error {
|
|
||||||
m.mustWrite()
|
m.mustWrite()
|
||||||
|
var (
|
||||||
|
read int64
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
if len(buf) < messageHeaderSize {
|
// reading packet into buffer
|
||||||
msg := fmt.Sprintln(len(buf), "<", messageHeaderSize, "message")
|
buf := make([]byte, MaxPacketSize)
|
||||||
return errors.Wrap(io.ErrUnexpectedEOF, msg)
|
if n, err = r.Read(buf); err != nil {
|
||||||
|
return int64(n), err
|
||||||
}
|
}
|
||||||
|
// writing to internal buffer
|
||||||
|
m.buf.B = m.buf.B[:0]
|
||||||
|
m.buf.Write(buf[:n])
|
||||||
|
read += int64(n)
|
||||||
|
// buf is now internal buffer
|
||||||
|
buf = m.buf.B[:n]
|
||||||
|
|
||||||
// decoding message header
|
// decoding message header
|
||||||
m.Type.ReadValue(binary.BigEndian.Uint16(buf[0:2])) // first 2 bytes
|
m.Type.ReadValue(binary.BigEndian.Uint16(buf[0:2])) // first 2 bytes
|
||||||
tLength := binary.BigEndian.Uint16(buf[2:4]) // second 2 bytes
|
tLength := binary.BigEndian.Uint16(buf[2:4]) // second 2 bytes
|
||||||
m.Length = uint32(tLength)
|
m.Length = uint32(tLength)
|
||||||
|
l := int(tLength)
|
||||||
cookie := binary.BigEndian.Uint32(buf[4:8])
|
cookie := binary.BigEndian.Uint32(buf[4:8])
|
||||||
copy(m.TransactionID[:], buf[8:messageHeaderSize])
|
copy(m.TransactionID[:], buf[8:messageHeaderSize])
|
||||||
|
|
||||||
if cookie != magicCookie {
|
if cookie != magicCookie {
|
||||||
return errors.Wrap(ErrInvalidMessageLength, "missmatch")
|
return read, ErrInvalidMagicCookie
|
||||||
}
|
}
|
||||||
|
|
||||||
offset := messageHeaderSize
|
buf = buf[messageHeaderSize:messageHeaderSize+l]
|
||||||
mLength := int(m.Length)
|
offset := 0
|
||||||
if (mLength + offset) > len(buf) {
|
for offset < l {
|
||||||
msg := fmt.Sprintln((mLength + offset), ">", len(buf))
|
|
||||||
return errors.Wrap(ErrInvalidMessageLength, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
for (mLength + messageHeaderSize - offset) > 0 {
|
|
||||||
b := buf[offset:]
|
b := buf[offset:]
|
||||||
// checking that we have enough bytes to read header
|
// checking that we have enough bytes to read header
|
||||||
if len(b) < attributeHeaderSize {
|
if len(b) < attributeHeaderSize {
|
||||||
msg := fmt.Sprintln(len(buf), "<", attributeHeaderSize, "header")
|
msg := fmt.Sprintln(len(buf), "<", attributeHeaderSize, "header")
|
||||||
return errors.Wrap(io.ErrUnexpectedEOF, msg)
|
return int64(offset) + read, errors.Wrap(io.ErrUnexpectedEOF, msg)
|
||||||
}
|
}
|
||||||
a := Attribute{}
|
a := Attribute{}
|
||||||
|
|
||||||
@@ -340,20 +345,20 @@ func (m *Message) Get(buf []byte) error {
|
|||||||
t := binary.BigEndian.Uint16(b[0:2]) // first 2 bytes
|
t := binary.BigEndian.Uint16(b[0:2]) // first 2 bytes
|
||||||
a.Length = binary.BigEndian.Uint16(b[2:4]) // second 2 bytes
|
a.Length = binary.BigEndian.Uint16(b[2:4]) // second 2 bytes
|
||||||
a.Type = AttrType(t)
|
a.Type = AttrType(t)
|
||||||
l := int(a.Length)
|
aL := int(a.Length)
|
||||||
|
|
||||||
// reading value
|
// reading value
|
||||||
b = b[attributeHeaderSize:] // slicing again to simplify value read
|
b = b[attributeHeaderSize:] // slicing again to simplify value read
|
||||||
if len(b) < l { // checking size
|
if len(b) < aL { // checking size
|
||||||
msg := fmt.Sprintf("attr len(b) == %d < %d", len(b), l)
|
msg := fmt.Sprintf("attr len(b) == %d < %d", len(b), aL)
|
||||||
return errors.Wrap(io.ErrUnexpectedEOF, msg)
|
return int64(offset) + read, errors.Wrap(io.ErrUnexpectedEOF, msg)
|
||||||
}
|
}
|
||||||
a.Value = b[:l]
|
a.Value = b[:aL]
|
||||||
|
|
||||||
m.Attributes = append(m.Attributes, a)
|
m.Attributes = append(m.Attributes, a)
|
||||||
offset += l + attributeHeaderSize
|
offset += aL + attributeHeaderSize
|
||||||
}
|
}
|
||||||
return nil
|
return int64(l + messageHeaderSize), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -361,6 +366,12 @@ const (
|
|||||||
messageHeaderSize = 20
|
messageHeaderSize = 20
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MaxPacketSize is maximum size of UDP packet that is processable in
|
||||||
|
// this package for STUN message.
|
||||||
|
MaxPacketSize = 2048
|
||||||
|
)
|
||||||
|
|
||||||
// AttrType is attribute type.
|
// AttrType is attribute type.
|
||||||
type AttrType uint16
|
type AttrType uint16
|
||||||
|
|
||||||
|
155
stun_test.go
155
stun_test.go
@@ -9,6 +9,8 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"bytes"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
@@ -41,8 +43,8 @@ func TestMessageBuffer(t *testing.T) {
|
|||||||
m.TransactionID = NewTransactionID()
|
m.TransactionID = NewTransactionID()
|
||||||
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||||
m.WriteHeader()
|
m.WriteHeader()
|
||||||
mDecoded := &Message{}
|
mDecoded := AcquireMessage()
|
||||||
if err := mDecoded.Get(m.buf.B); err != nil {
|
if _, err := mDecoded.ReadFrom(bytes.NewReader(m.buf.B)); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if !mDecoded.Equal(*m) {
|
if !mDecoded.Equal(*m) {
|
||||||
@@ -123,51 +125,71 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_PutGet(t *testing.T) {
|
//func TestMessage_PutGet(t *testing.T) {
|
||||||
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
// mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||||
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
|
// messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
|
||||||
messageAttributes := Attributes{
|
// messageAttributes := Attributes{
|
||||||
messageAttribute,
|
// messageAttribute,
|
||||||
|
// }
|
||||||
|
// m := Message{
|
||||||
|
// Type: mType,
|
||||||
|
// Length: 6,
|
||||||
|
// TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
||||||
|
// Attributes: messageAttributes,
|
||||||
|
// }
|
||||||
|
// buf := make([]byte, 128)
|
||||||
|
// m.Put(buf)
|
||||||
|
// mDecoded := Message{}
|
||||||
|
// if err := mDecoded.Get(buf); err != nil {
|
||||||
|
// t.Error(err)
|
||||||
|
// }
|
||||||
|
// if mDecoded.Type != m.Type {
|
||||||
|
// t.Error("incorrect type")
|
||||||
|
// }
|
||||||
|
// if mDecoded.Length != m.Length {
|
||||||
|
// t.Error("incorrect length")
|
||||||
|
// }
|
||||||
|
// if mDecoded.TransactionID != m.TransactionID {
|
||||||
|
// t.Error("incorrect transaction ID")
|
||||||
|
// }
|
||||||
|
// aDecoded := mDecoded.Attributes.Get(messageAttribute.Type)
|
||||||
|
// if !aDecoded.Equal(messageAttribute) {
|
||||||
|
// t.Error(aDecoded, "!=", messageAttribute)
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
func TestMessage_WriteTo(t *testing.T) {
|
||||||
|
m := AcquireMessage()
|
||||||
|
defer ReleaseMessage(m)
|
||||||
|
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||||
|
m.TransactionID = NewTransactionID()
|
||||||
|
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||||
|
m.WriteHeader()
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
if _, err := m.WriteTo(buf); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
m := Message{
|
mDecoded := AcquireMessage()
|
||||||
Type: mType,
|
if _, err := mDecoded.ReadFrom(buf); err != nil {
|
||||||
Length: 6,
|
|
||||||
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
|
||||||
Attributes: messageAttributes,
|
|
||||||
}
|
|
||||||
buf := make([]byte, 128)
|
|
||||||
m.Put(buf)
|
|
||||||
mDecoded := Message{}
|
|
||||||
if err := mDecoded.Get(buf); err != nil {
|
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if mDecoded.Type != m.Type {
|
if !mDecoded.Equal(*m) {
|
||||||
t.Error("incorrect type")
|
t.Error(mDecoded, "!", m)
|
||||||
}
|
|
||||||
if mDecoded.Length != m.Length {
|
|
||||||
t.Error("incorrect length")
|
|
||||||
}
|
|
||||||
if mDecoded.TransactionID != m.TransactionID {
|
|
||||||
t.Error("incorrect transaction ID")
|
|
||||||
}
|
|
||||||
aDecoded := mDecoded.Attributes.Get(messageAttribute.Type)
|
|
||||||
if !aDecoded.Equal(messageAttribute) {
|
|
||||||
t.Error(aDecoded, "!=", messageAttribute)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_Cookie(t *testing.T) {
|
func TestMessage_Cookie(t *testing.T) {
|
||||||
buf := make([]byte, 20)
|
buf := make([]byte, 20)
|
||||||
mDecoded := Message{}
|
mDecoded := AcquireMessage()
|
||||||
if err := mDecoded.Get(buf); err == nil {
|
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
|
||||||
t.Error("should error")
|
t.Error("should error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessage_LengthLessHeaderSize(t *testing.T) {
|
func TestMessage_LengthLessHeaderSize(t *testing.T) {
|
||||||
buf := make([]byte, 8)
|
buf := make([]byte, 8)
|
||||||
mDecoded := Message{}
|
mDecoded := AcquireMessage()
|
||||||
if err := mDecoded.Get(buf); err == nil {
|
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
|
||||||
t.Error("should error")
|
t.Error("should error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -175,19 +197,17 @@ func TestMessage_LengthLessHeaderSize(t *testing.T) {
|
|||||||
func TestMessage_BadLength(t *testing.T) {
|
func TestMessage_BadLength(t *testing.T) {
|
||||||
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||||
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
|
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
|
||||||
messageAttributes := Attributes{
|
m := AcquireFields(Message{
|
||||||
messageAttribute,
|
|
||||||
}
|
|
||||||
m := Message{
|
|
||||||
Type: mType,
|
Type: mType,
|
||||||
Length: 4,
|
Length: 4,
|
||||||
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
||||||
Attributes: messageAttributes,
|
Attributes: []Attribute{messageAttribute},
|
||||||
}
|
})
|
||||||
buf := make([]byte, 128)
|
b := new(bytes.Buffer)
|
||||||
m.Put(buf)
|
m.WriteTo(b)
|
||||||
mDecoded := Message{}
|
b.Truncate(20 + 3)
|
||||||
if err := mDecoded.Get(buf[:20+3]); err == nil {
|
mDecoded := AcquireMessage()
|
||||||
|
if _, err := mDecoded.ReadFrom(b); err == nil {
|
||||||
t.Error("should error")
|
t.Error("should error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -198,16 +218,17 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
|
|||||||
messageAttributes := Attributes{
|
messageAttributes := Attributes{
|
||||||
messageAttribute,
|
messageAttribute,
|
||||||
}
|
}
|
||||||
m := Message{
|
m := AcquireFields(Message{
|
||||||
Type: mType,
|
Type: mType,
|
||||||
TransactionID: NewTransactionID(),
|
TransactionID: NewTransactionID(),
|
||||||
Attributes: messageAttributes,
|
Attributes: messageAttributes,
|
||||||
}
|
})
|
||||||
buf := make([]byte, 128)
|
b := new(bytes.Buffer)
|
||||||
m.Put(buf)
|
m.WriteTo(b)
|
||||||
|
buf := b.Bytes()
|
||||||
|
mDecoded := AcquireMessage()
|
||||||
binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length
|
binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length
|
||||||
mDecoded := Message{}
|
_, err := mDecoded.ReadFrom(bytes.NewReader(buf[:20+2]))
|
||||||
err := mDecoded.Get(buf[:20+2])
|
|
||||||
if errors.Cause(err) != io.ErrUnexpectedEOF {
|
if errors.Cause(err) != io.ErrUnexpectedEOF {
|
||||||
t.Error(err, "should be", io.ErrUnexpectedEOF)
|
t.Error(err, "should be", io.ErrUnexpectedEOF)
|
||||||
}
|
}
|
||||||
@@ -215,20 +236,37 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
|
|||||||
|
|
||||||
func TestMessage_AttrSizeLessThanLength(t *testing.T) {
|
func TestMessage_AttrSizeLessThanLength(t *testing.T) {
|
||||||
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
mType := MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||||
messageAttribute := Attribute{Length: 2, Value: []byte{1, 2}, Type: 0x1}
|
messageAttribute := Attribute{Length: 4,
|
||||||
|
Value: []byte{1, 2, 3, 4}, Type: 0x1,
|
||||||
|
}
|
||||||
messageAttributes := Attributes{
|
messageAttributes := Attributes{
|
||||||
messageAttribute,
|
messageAttribute,
|
||||||
}
|
}
|
||||||
m := Message{
|
m := AcquireFields(Message{
|
||||||
Type: mType,
|
Type: mType,
|
||||||
TransactionID: NewTransactionID(),
|
TransactionID: NewTransactionID(),
|
||||||
Attributes: messageAttributes,
|
Attributes: messageAttributes,
|
||||||
|
})
|
||||||
|
b := new(bytes.Buffer)
|
||||||
|
m.WriteTo(b)
|
||||||
|
buf := b.Bytes()
|
||||||
|
binary.BigEndian.PutUint16(buf[2:4], 5) // rewrite to bad length
|
||||||
|
mDecoded := AcquireMessage()
|
||||||
|
_, err := mDecoded.ReadFrom(bytes.NewReader(buf[:20+5]))
|
||||||
|
if errors.Cause(err) != io.ErrUnexpectedEOF {
|
||||||
|
t.Error(err, "should be", io.ErrUnexpectedEOF)
|
||||||
}
|
}
|
||||||
buf := make([]byte, 128)
|
}
|
||||||
m.Put(buf)
|
|
||||||
binary.BigEndian.PutUint16(buf[2:4], 2) // rewrite to bad length
|
type unexpectedEOFReader struct{}
|
||||||
mDecoded := Message{}
|
|
||||||
err := mDecoded.Get(buf[:20+5])
|
func (_ unexpectedEOFReader) Read(b []byte) (int, error) {
|
||||||
|
return 0, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessage_ReadFromError(t *testing.T) {
|
||||||
|
mDecoded := AcquireMessage()
|
||||||
|
_, err := mDecoded.ReadFrom(unexpectedEOFReader{})
|
||||||
if errors.Cause(err) != io.ErrUnexpectedEOF {
|
if errors.Cause(err) != io.ErrUnexpectedEOF {
|
||||||
t.Error(err, "should be", io.ErrUnexpectedEOF)
|
t.Error(err, "should be", io.ErrUnexpectedEOF)
|
||||||
}
|
}
|
||||||
@@ -249,10 +287,11 @@ func BenchmarkMessage_Put(b *testing.B) {
|
|||||||
Length: 0,
|
Length: 0,
|
||||||
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
||||||
}
|
}
|
||||||
buf := make([]byte, 20)
|
buf := new(bytes.Buffer)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
m.Put(buf)
|
m.WriteTo(buf)
|
||||||
|
buf.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user