all: refactor attributes

This commit is contained in:
Aleksandr Razumov
2017-02-11 05:23:49 +03:00
parent 3ec51bf758
commit fba9f6b02d
21 changed files with 1048 additions and 1107 deletions

View File

@@ -1,4 +1,4 @@
Copyright (c) 2016 Aleksandr Razumov, Cydev. All Rigths Reserved. Copyright (c) 2016-2017 Aleksandr Razumov, Cydev. All Rigths Reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are modification, are permitted provided that the following conditions are

View File

@@ -2,39 +2,88 @@ package stun
// ErrorCodeAttribute represents ERROR-CODE attribute. // ErrorCodeAttribute represents ERROR-CODE attribute.
type ErrorCodeAttribute struct { type ErrorCodeAttribute struct {
Code int Code ErrorCode
Reason []byte Reason []byte
} }
// constants for ERROR-CODE encoding.
const (
errorCodeReasonStart = 4
errorCodeClassByte = 2
errorCodeNumberByte = 3
errorCodeReasonMaxB = 763
errorCodeModulo = 100
)
// AddTo adds ERROR-CODE to m.
func (c *ErrorCodeAttribute) AddTo(m *Message) error {
value := make([]byte,
errorCodeReasonStart, errorCodeReasonMaxB,
)
number := byte(c.Code % errorCodeModulo) // error code modulo 100
class := byte(c.Code / errorCodeModulo) // hundred digit
value[errorCodeClassByte] = class
value[errorCodeNumberByte] = number
value = append(value, c.Reason...)
m.Add(AttrErrorCode, value)
return nil
}
// GetFrom decodes ERROR-CODE from m.
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
v, err := m.Get(AttrErrorCode)
if err != nil {
return err
}
var (
class = uint16(v[errorCodeClassByte])
number = uint16(v[errorCodeNumberByte])
code = int(class*errorCodeModulo + number)
reason = v[errorCodeReasonStart:]
)
c.Code = ErrorCode(code)
c.Reason = reason
return nil
}
// ErrorCode is code for ERROR-CODE attribute. // ErrorCode is code for ERROR-CODE attribute.
type ErrorCode int type ErrorCode int
// ErrNoDefaultReason means that default reason for provided error code
// is not defined in RFC.
const ErrNoDefaultReason Error = "No default reason for ErrorCode"
// AddTo adds ERROR-CODE with default reason to m. If there
// is no default reason, returns ErrNoDefaultReason.
func (c ErrorCode) AddTo(m *Message) error {
reason := errorReasons[c]
if reason == nil {
return ErrNoDefaultReason
}
a := &ErrorCodeAttribute{
Code: c,
Reason: reason,
}
return a.AddTo(m)
}
// Possible error codes. // Possible error codes.
const ( const (
CodeTryAlternate = 300 CodeTryAlternate ErrorCode = 300
CodeBadRequest = 400 CodeBadRequest ErrorCode = 400
CodeUnauthorised = 401 CodeUnauthorised ErrorCode = 401
CodeUnknownAttribute = 420 CodeUnknownAttribute ErrorCode = 420
CodeStaleNonce = 428 CodeStaleNonce ErrorCode = 428
CodeRoleConflict = 478 CodeRoleConflict ErrorCode = 478
CodeServerError = 500 CodeServerError ErrorCode = 500
) )
var errorReasons = map[int]string{ var errorReasons = map[ErrorCode][]byte{
CodeTryAlternate: "Try Alternate", CodeTryAlternate: []byte("Try Alternate"),
CodeBadRequest: "Bad Request", CodeBadRequest: []byte("Bad Request"),
CodeUnauthorised: "Unauthorised", CodeUnauthorised: []byte("Unauthorised"),
CodeUnknownAttribute: "Unknown Attribute", CodeUnknownAttribute: []byte("Unknown Attribute"),
CodeStaleNonce: "Stale Nonce", CodeStaleNonce: []byte("Stale Nonce"),
CodeServerError: "Server Error", CodeServerError: []byte("Server Error"),
CodeRoleConflict: "Role Conflict", CodeRoleConflict: []byte("Role Conflict"),
}
// Reason returns recommended reason string.
func (c ErrorCode) Reason() string {
reason, ok := errorReasons[int(c)]
if !ok {
return "Unknown Error"
}
return reason
} }

View File

@@ -1,26 +0,0 @@
package stun
import "testing"
func TestErrorCode_Reason(t *testing.T) {
codes := [...]ErrorCode{
CodeTryAlternate,
CodeBadRequest,
CodeUnauthorised,
CodeUnknownAttribute,
CodeStaleNonce,
CodeRoleConflict,
CodeServerError,
}
for _, code := range codes {
if code.Reason() == "Unknown Error" {
t.Error(code, "should not be unknown")
}
if len(code.Reason()) == 0 {
t.Error(code, "should not be blank")
}
}
if ErrorCode(999).Reason() != "Unknown Error" {
t.Error("999 error should be Unknown")
}
}

68
attribute_fingerprint.go Normal file
View File

@@ -0,0 +1,68 @@
package stun
import (
"fmt"
"hash/crc32"
)
// FingerprintAttr represent FINGERPRINT attribute.
type FingerprintAttr struct{}
// CRCMismatch represents CRC check error.
type CRCMismatch struct {
Expected uint32
Actual uint32
}
func (m CRCMismatch) Error() string {
return fmt.Sprintf("CRC mismatch: %x (expected) != %x (actual)",
m.Expected,
m.Actual,
)
}
// Fingerprint is shorthand for FingerprintAttr.
var Fingerprint = &FingerprintAttr{}
const (
fingerprintXORValue uint32 = 0x5354554e
fingerprintSize = 4 // 32 bit
)
// FingerprintValue returns CRC32 of m XOR-ed by 0x5354554e.
func FingerprintValue(b []byte) uint32 {
return crc32.ChecksumIEEE(b) ^ fingerprintXORValue // XOR
}
// AddTo adds fingerprint to message.
func (FingerprintAttr) AddTo(m *Message) error {
l := m.Length
// length in header should include size of fingerprint attribute
m.Length += fingerprintSize + attributeHeaderSize // increasing length
m.WriteLength() // writing Length to Raw
b := make([]byte, fingerprintSize)
v := FingerprintValue(m.Raw)
bin.PutUint32(b, v)
m.Length = l
m.Add(AttrFingerprint, b)
return nil
}
// Check reads fingerprint value from m and checks it, returning error if any.
// Can return *DecodeErr, ErrAttributeNotFound and *CRCMismatch.
func (FingerprintAttr) Check(m *Message) error {
v, err := m.Get(AttrFingerprint)
if err != nil {
return err
}
if len(v) != fingerprintSize {
return newDecodeErr("message", "fingerprint", "bad length")
}
val := bin.Uint32(v)
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
expected := FingerprintValue(m.Raw[:attrStart])
if expected != val {
return &CRCMismatch{Expected: expected, Actual: val}
}
return nil
}

31
attribute_software.go Normal file
View File

@@ -0,0 +1,31 @@
package stun
// Software is SOFTWARE attribute.
type Software struct {
Raw []byte
}
func (s *Software) String() string {
return string(s.Raw)
}
// NewSoftware returns *Software from string.
func NewSoftware(software string) *Software {
return &Software{Raw: []byte(software)}
}
// AddTo adds Software attribute to m.
func (s *Software) AddTo(m *Message) error {
m.Add(AttrSoftware, m.Raw)
return nil
}
// GetFrom decodes Software from m.
func (s *Software) GetFrom(m *Message) error {
v, err := m.Get(AttrSoftware)
if err != nil {
return err
}
s.Raw = v
return nil
}

118
attribute_xoraddr.go Normal file
View File

@@ -0,0 +1,118 @@
package stun
import (
"fmt"
"net"
)
const (
familyIPv4 byte = 0x01
familyIPv6 byte = 0x02
)
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
type XORMappedAddress struct {
IP net.IP
Port int
}
func (a XORMappedAddress) String() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
// Is p all zeros?
func isZeros(p net.IP) bool {
for i := 0; i < len(p); i++ {
if p[i] != 0 {
return false
}
}
return true
}
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
const ErrBadIPLength Error = "invalid length if IP value"
// AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength
// if len(a.IP) is invalid.
func (a *XORMappedAddress) AddTo(m *Message) error {
var (
family = familyIPv4
ip = a.IP
)
if len(a.IP) == net.IPv6len {
// Optimized for performance. See net.IP.To4 method.
if isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
ip = ip[12:16] // like in ip.To4()
} else {
family = familyIPv6
}
} else if len(ip) != net.IPv4len {
return ErrBadIPLength
}
value := make([]byte, 32+128)
value[0] = 0 // first 8 bits are zeroes
xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:])
bin.PutUint32(xorValue[0:4], magicCookie)
bin.PutUint16(value[0:2], uint16(family))
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16))
xorBytes(value[4:4+len(ip)], ip, xorValue)
m.Add(AttrXORMappedAddress, value[:4+len(ip)])
return nil
}
// GetFrom decodes XOR-MAPPED-ADDRESS attribute in message and returns
// error if any. While decoding, a.IP is reused if possible and can be
// rendered to invalid state (e.g. if a.IP was set to IPv6 and then
// IPv4 value were decoded into it), be careful.
//
// Example:
//
// expectedIP := net.ParseIP("213.141.156.236")
// expectedIP.String() // 213.141.156.236, 16 bytes, first 12 of them are zeroes
// expectedPort := 21254
// addr := &XORMappedAddress{
// IP: expectedIP,
// Port: expectedPort,
// }
// // addr were added to message that is decoded as newMessage
// // ...
//
// addr.GetFrom(newMessage)
// addr.IP.String() // 213.141.156.236, net.IPv4Len
// expectedIP.String() // d58d:9cec::ffff:d58d:9cec, 16 bytes, first 4 are IPv4
// // now we have len(expectedIP) = 16 and len(addr.IP) = 4.
func (a *XORMappedAddress) GetFrom(m *Message) error {
v, err := m.Get(AttrXORMappedAddress)
if err != nil {
return err
}
family := byte(bin.Uint16(v[0:2]))
if family != familyIPv6 && family != familyIPv4 {
return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
)
}
ipLen := net.IPv4len
if family == familyIPv6 {
ipLen = net.IPv6len
}
// Ensuring len(a.IP) == ipLen and reusing a.IP.
if len(a.IP) < ipLen {
a.IP = a.IP[:cap(a.IP)]
for len(a.IP) < ipLen {
a.IP = append(a.IP, 0)
}
}
a.IP = a.IP[:ipLen]
for i := range a.IP {
a.IP[i] = 0
}
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 4+transactionIDSize)
bin.PutUint32(xorValue[0:4], magicCookie)
copy(xorValue[4:], m.TransactionID[:])
xorBytes(a.IP, v[4:], xorValue)
return nil
}

View File

@@ -1,23 +1,10 @@
package stun package stun
import ( import (
"encoding/binary"
"fmt" "fmt"
"hash/crc32"
"net"
"strconv" "strconv"
) )
// AttrWriter wraps AddRaw method.
type AttrWriter interface {
AddRaw(t AttrType, v []byte)
}
// AttrEncoder wraps Encode method.
type AttrEncoder interface {
Encode(b []byte, m *Message) (AttrType, []byte, error)
}
// Attributes is list of message attributes. // Attributes is list of message attributes.
type Attributes []RawAttribute type Attributes []RawAttribute
@@ -77,7 +64,7 @@ const (
AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN
) )
// Attributes from An Origin RawAttribute for the STUN Protocol. // Attributes from An Origin Attribute for the STUN Protocol.
const ( const (
AttrOrigin AttrType = 0x802F AttrOrigin AttrType = 0x802F
) )
@@ -118,7 +105,7 @@ var attrNames = map[AttrType]string{
func (t AttrType) String() string { func (t AttrType) String() string {
s, ok := attrNames[t] s, ok := attrNames[t]
if !ok { if !ok {
// just return hex representation of unknown attribute type // Just return hex representation of unknown attribute type.
return "0x" + strconv.FormatUint(uint64(t), 16) return "0x" + strconv.FormatUint(uint64(t), 16)
} }
return s return s
@@ -131,21 +118,17 @@ func (t AttrType) String() string {
// don't understand, but cannot successfully process a message if it // don't understand, but cannot successfully process a message if it
// contains comprehension-required attributes that are not // contains comprehension-required attributes that are not
// understood. // understood.
//
// TODO(ar): Decide to use pointer or non-pointer RawAttribute.
type RawAttribute struct { type RawAttribute struct {
Type AttrType Type AttrType
Length uint16 // ignored while encoding Length uint16 // ignored while encoding
Value []byte Value []byte
} }
// Encode implements AttrEncoder. // AddTo adds RawAttribute to m.
func (a *RawAttribute) Encode(m *Message) ([]byte, error) { func (a *RawAttribute) AddTo(m *Message) error {
return m.Raw, nil m.Add(a.Type, m.Raw)
}
// Decode implements AttrDecoder.
func (a *RawAttribute) Decode(v []byte, m *Message) error {
a.Value = v
a.Length = uint16(len(v))
return nil return nil
} }
@@ -172,319 +155,17 @@ func (a RawAttribute) String() string {
return fmt.Sprintf("%s: %x", a.Type, a.Value) return fmt.Sprintf("%s: %x", a.Type, a.Value)
} }
// getAttrValue returns byte slice that represents attribute value, // ErrAttributeNotFound means that attribute with provided attribute
// if there is no value attribute with shuck type, // type does not exist in message.
const ErrAttributeNotFound Error = "Attribute not found"
// Get returns byte slice that represents attribute value,
// if there is no attribute with such type,
// ErrAttributeNotFound is returned. // ErrAttributeNotFound is returned.
func (m *Message) getAttrValue(t AttrType) ([]byte, error) { func (m *Message) Get(t AttrType) ([]byte, error) {
v, ok := m.Attributes.Get(t) v, ok := m.Attributes.Get(t)
if !ok { if !ok {
return nil, ErrAttributeNotFound return nil, ErrAttributeNotFound
} }
return v.Value, nil return v.Value, nil
} }
// AddSoftware adds SOFTWARE attribute with value from string.
// Deprecated: use AddRaw.
func (m *Message) AddSoftware(software string) {
m.AddRaw(AttrSoftware, []byte(software))
}
// Set sets the value of attribute if it presents.
func (m *Message) Set(a AttrEncoder) error {
var (
v []byte
err error
t AttrType
)
t, v, err = a.Encode(v, m)
if err != nil {
return err
}
buf, err := m.getAttrValue(t)
if err != nil {
return err
}
if len(v) != len(buf) {
return ErrBadSetLength
}
copy(buf, v)
return nil
}
// GetSoftwareBytes returns SOFTWARE attribute value in byte slice.
// If not found, returns nil.
func (m *Message) GetSoftwareBytes() []byte {
v, ok := m.Attributes.Get(AttrSoftware)
if !ok {
return nil
}
return v.Value
}
// GetSoftware returns SOFTWARE attribute value in string.
// If not found, returns blank string.
// Deprecated.
func (m *Message) GetSoftware() string { return string(m.GetSoftwareBytes()) }
// Address family values.
const (
FamilyIPv4 byte = 0x01
FamilyIPv6 byte = 0x02
)
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
type XORMappedAddress struct {
ip net.IP
port int
}
// Encode implements AttrEncoder.
func (a *XORMappedAddress) Encode(buf []byte, m *Message) (AttrType, []byte, error) {
// X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order.
family := FamilyIPv6
ip := a.ip
port := a.port
if ipV4 := ip.To4(); ipV4 != nil {
ip = ipV4
family = FamilyIPv4
}
value := make([]byte, 32+128)
value[0] = 0 // first 8 bits are zeroes
xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:])
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
port ^= magicCookie >> 16
binary.BigEndian.PutUint16(value[0:2], uint16(family))
binary.BigEndian.PutUint16(value[2:4], uint16(port))
xorBytes(value[4:4+len(ip)], ip, xorValue)
buf = append(buf, value...)
return AttrXORMappedAddress, buf, nil
}
// Decode implements AttrDecoder.
// TODO(ar): fix signature.
func (a *XORMappedAddress) Decode(v []byte, m *Message) error {
// X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order.
v, err := m.getAttrValue(AttrXORMappedAddress)
if err != nil {
return err
}
family := byte(binary.BigEndian.Uint16(v[0:2]))
if family != FamilyIPv6 && family != FamilyIPv4 {
return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
)
}
ipLen := net.IPv4len
if family == FamilyIPv6 {
ipLen = net.IPv6len
}
ip := net.IP(m.allocBuffer(ipLen))
a.port = int(binary.BigEndian.Uint16(v[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 128)
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
copy(xorValue[4:], m.TransactionID[:])
xorBytes(ip, v[4:], xorValue)
a.ip = ip
return nil
}
// AddXORMappedAddress adds XOR MAPPED ADDRESS attribute to message.
// Deprecated: use AddRaw.
func (m *Message) AddXORMappedAddress(ip net.IP, port int) {
// X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order.
family := FamilyIPv6
if ipV4 := ip.To4(); ipV4 != nil {
ip = ipV4
family = FamilyIPv4
}
value := make([]byte, 32+128)
value[0] = 0 // first 8 bits are zeroes
xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:])
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
port ^= magicCookie >> 16
binary.BigEndian.PutUint16(value[0:2], uint16(family))
binary.BigEndian.PutUint16(value[2:4], uint16(port))
xorBytes(value[4:4+len(ip)], ip, xorValue)
m.AddRaw(AttrXORMappedAddress, value[:4+len(ip)])
}
func (m *Message) allocBuffer(size int) []byte {
capacity := len(m.Raw) + size
m.grow(capacity)
m.Raw = m.Raw[:capacity]
return m.Raw[len(m.Raw)-size:]
}
// GetXORMappedAddress returns ip, port from attribute and error if any.
// Value for ip is valid until Message is released or underlying buffer is
// corrupted. Returns *DecodeError or ErrAttributeNotFound.
// Deprecated: use GetRaw.
func (m *Message) GetXORMappedAddress() (net.IP, int, error) {
// X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order.
v, err := m.getAttrValue(AttrXORMappedAddress)
if err != nil {
return nil, 0, err
}
family := byte(binary.BigEndian.Uint16(v[0:2]))
if family != FamilyIPv6 && family != FamilyIPv4 {
return nil, 0, newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
)
}
ipLen := net.IPv4len
if family == FamilyIPv6 {
ipLen = net.IPv6len
}
ip := net.IP(m.allocBuffer(ipLen))
port := int(binary.BigEndian.Uint16(v[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 128)
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
copy(xorValue[4:], m.TransactionID[:])
xorBytes(ip, v[4:], xorValue)
return ip, port, nil
}
// constants for ERROR-CODE encoding.
const (
errorCodeReasonStart = 4
errorCodeClassByte = 2
errorCodeNumberByte = 3
errorCodeReasonMaxB = 763
errorCodeModulo = 100
)
// AddErrorCode adds ERROR-CODE attribute to message.
//
// The reason phrase MUST be a UTF-8 [RFC 3629] encoded
// sequence of less than 128 characters (which can be as long as 763
// bytes).
// Deprecated: use AddRaw.
func (m *Message) AddErrorCode(code int, reason string) {
value := make([]byte,
errorCodeReasonStart, errorCodeReasonMaxB+errorCodeReasonStart,
)
number := byte(code % errorCodeModulo) // error code modulo 100
class := byte(code / errorCodeModulo) // hundred digit
value[errorCodeClassByte] = class
value[errorCodeNumberByte] = number
value = append(value, reason...)
m.AddRaw(AttrErrorCode, value)
}
// AddErrorCodeDefault is wrapper for AddErrorCode that uses recommended
// reason string from RFC. If error code is unknown, reason will be "Unknown
// Error".
// Deprecated: use AddRaw.
func (m *Message) AddErrorCodeDefault(code int) {
m.AddErrorCode(code, ErrorCode(code).Reason())
}
// GetErrorCode returns ERROR-CODE code, reason and decode error if any.
// Deprecated: use GetRaw.
func (m *Message) GetErrorCode() (int, []byte, error) {
v, err := m.getAttrValue(AttrErrorCode)
if err != nil {
return 0, nil, err
}
var (
class = uint16(v[errorCodeClassByte])
number = uint16(v[errorCodeNumberByte])
code = int(class*errorCodeModulo + number)
reason = v[errorCodeReasonStart:]
)
return code, reason, nil
}
const (
// ErrAttributeNotFound means that there is no such attribute.
ErrAttributeNotFound Error = "Attribute not found"
// ErrBadSetLength means that previous attribute value length differs from
// new value.
ErrBadSetLength Error = "Previous attribute length is different"
)
// Software is SOFTWARE attribute.
type Software struct {
Raw []byte
}
func (s Software) String() string {
return string(s.Raw)
}
// NewSoftware returns *Software from string.
func NewSoftware(software string) *Software {
return &Software{Raw: []byte(software)}
}
// Encode implements AttrEncoder.
func (s *Software) Encode(b []byte, m *Message) (AttrType, []byte, error) {
return AttrSoftware, append(b, s.Raw...), nil
}
// Decode implements AttrDecoder.
func (s *Software) Decode(v []byte, m *Message) error {
s.Raw = v
return nil
}
const (
fingerprintXORValue uint32 = 0x5354554e
)
// Fingerprint represents FINGERPRINT attribute.
type Fingerprint struct {
Value uint32 // CRC-32 of message XOR-ed with 0x5354554e
}
const (
fingerprintSize = 4 // 32 bit
)
// AddTo adds fingerprint to message.
func (f *Fingerprint) AddTo(m *Message) error {
l := m.Length
// length in header should include size of fingerprint attribute
m.Length += fingerprintSize + attributeHeaderSize // increasing length
m.WriteLength() // writing Length to Raw
b := make([]byte, fingerprintSize)
f.Value = crc32.ChecksumIEEE(m.Raw) ^ fingerprintXORValue // XOR
bin.PutUint32(b, f.Value)
m.Length = l
m.AddRaw(AttrFingerprint, b)
return nil
}
// Check reads fingerprint value from m and checks it, returning error if any.
// Can return *DecodeErr, ErrAttributeNotFound, ErrCRCMissmatch.
func (f *Fingerprint) Check(m *Message) error {
v, err := m.getAttrValue(AttrFingerprint)
if err != nil {
return err
}
if len(v) != fingerprintSize {
return newDecodeErr("message", "fingerprint", "bad length")
}
f.Value = bin.Uint32(v)
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
expected := crc32.ChecksumIEEE(m.Raw[:attrStart]) ^ fingerprintXORValue
if expected != f.Value {
return ErrCRCMissmatch
}
return nil
}
// ErrCRCMissmatch means that calculated fingerprint attribute differs from
// expected one.
const ErrCRCMissmatch Error = "CRC32 missmatch: bad fingerprint value"

View File

@@ -10,21 +10,24 @@ import (
"testing" "testing"
) )
func TestMessage_AddSoftware(t *testing.T) { func TestSoftware_GetFrom(t *testing.T) {
m := New() m := New()
v := "Client v0.0.1" v := "Client v0.0.1"
m.AddRaw(AttrSoftware, []byte(v)) m.Add(AttrSoftware, []byte(v))
m.WriteHeader() m.WriteHeader()
m2 := &Message{ m2 := &Message{
Raw: make([]byte, 0, 256), Raw: make([]byte, 0, 256),
} }
software := new(Software)
if _, err := m2.ReadFrom(m.reader()); err != nil { if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err) t.Error(err)
} }
vRead := m.GetSoftware() if err := software.GetFrom(m); err != nil {
if vRead != v { t.Fatal(err)
t.Errorf("Expected %s, got %s.", v, vRead) }
if software.String() != v {
t.Errorf("Expected %q, got %q.", v, software)
} }
sAttr, ok := m.Attributes.Get(AttrSoftware) sAttr, ok := m.Attributes.Get(AttrSoftware)
@@ -37,29 +40,18 @@ func TestMessage_AddSoftware(t *testing.T) {
} }
} }
func TestMessage_GetSoftware(t *testing.T) { func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
m := New()
v := m.GetSoftware()
if v != "" {
t.Errorf("%s should be blank.", v)
}
vByte := m.GetSoftwareBytes()
if vByte != nil {
t.Errorf("%s should be nil.", vByte)
}
}
func BenchmarkMessage_AddXORMappedAddress(b *testing.B) {
m := New() m := New()
b.ReportAllocs() b.ReportAllocs()
ip := net.ParseIP("192.168.1.32") ip := net.ParseIP("192.168.1.32")
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.AddXORMappedAddress(ip, 3654) addr := &XORMappedAddress{IP: ip, Port: 3654}
addr.AddTo(m)
m.Reset() m.Reset()
} }
} }
func BenchmarkMessage_GetXORMappedAddress(b *testing.B) { func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
m := New() m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { if err != nil {
@@ -70,10 +62,13 @@ func BenchmarkMessage_GetXORMappedAddress(b *testing.B) {
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
m.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress)
b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.AddRaw(AttrXORMappedAddress, addrValue) if err := addr.GetFrom(m); err != nil {
m.GetXORMappedAddress() b.Fatal(err)
m.Reset() }
} }
} }
@@ -88,16 +83,16 @@ func TestMessage_GetXORMappedAddress(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
m.AddRaw(AttrXORMappedAddress, addrValue) m.Add(AttrXORMappedAddress, addrValue)
ip, port, err := m.GetXORMappedAddress() addr := new(XORMappedAddress)
if err != nil { if err = addr.GetFrom(m); err != nil {
t.Error(err) t.Error(err)
} }
if !ip.Equal(net.ParseIP("213.141.156.236")) { if !addr.IP.Equal(net.ParseIP("213.141.156.236")) {
t.Error("bad ip", ip, "!=", "213.141.156.236") t.Error("bad IP", addr.IP, "!=", "213.141.156.236")
} }
if port != 48583 { if addr.Port != 48583 {
t.Error("bad port", port, "!=", 48583) t.Error("bad Port", addr.Port, "!=", 48583)
} }
} }
@@ -110,13 +105,15 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
copy(m.TransactionID[:], transactionID) copy(m.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236") expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254 expectedPort := 21254
addr := new(XORMappedAddress)
_, _, err = m.GetXORMappedAddress() if err = addr.GetFrom(m); err == nil {
if err == nil {
t.Fatal(err, "should be nil") t.Fatal(err, "should be nil")
} }
m.AddXORMappedAddress(expectedIP, expectedPort) addr.IP = expectedIP
addr.Port = expectedPort
addr.AddTo(m)
m.WriteHeader() m.WriteHeader()
mRes := New() mRes := New()
@@ -124,8 +121,7 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil { if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, _, err = m.GetXORMappedAddress() if err = addr.GetFrom(m); err == nil {
if err == nil {
t.Fatal(err, "should not be nil") t.Fatal(err, "should not be nil")
} }
} }
@@ -139,22 +135,26 @@ func TestMessage_AddXORMappedAddress(t *testing.T) {
copy(m.TransactionID[:], transactionID) copy(m.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236") expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254 expectedPort := 21254
m.AddXORMappedAddress(expectedIP, expectedPort) addr := &XORMappedAddress{
IP: net.ParseIP("213.141.156.236"),
Port: expectedPort,
}
if err = addr.AddTo(m); err != nil {
t.Fatal(err)
}
m.WriteHeader() m.WriteHeader()
mRes := New() mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil { if _, err = mRes.Write(m.Raw); err != nil {
t.Fatal(err) t.Fatal(err)
} }
ip, port, err := m.GetXORMappedAddress() if err = addr.GetFrom(mRes); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !ip.Equal(expectedIP) { if !addr.IP.Equal(expectedIP) {
t.Error("bad ip", ip, "!=", expectedIP) t.Errorf("%s (got) != %s (expected)", addr.IP, expectedIP)
} }
if port != expectedPort { if addr.Port != expectedPort {
t.Error("bad port", port, "!=", expectedPort) t.Error("bad Port", addr.Port, "!=", expectedPort)
} }
} }
@@ -167,30 +167,47 @@ func TestMessage_AddXORMappedAddressV6(t *testing.T) {
copy(m.TransactionID[:], transactionID) copy(m.TransactionID[:], transactionID)
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009") expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
expectedPort := 21254 expectedPort := 21254
m.AddXORMappedAddress(expectedIP, expectedPort) addr := &XORMappedAddress{
IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"),
Port: 21254,
}
addr.AddTo(m)
m.WriteHeader() m.WriteHeader()
mRes := New() mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil { if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err) t.Fatal(err)
} }
ip, port, err := m.GetXORMappedAddress() gotAddr := new(XORMappedAddress)
if err != nil { if err = gotAddr.GetFrom(m); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !ip.Equal(expectedIP) { if !gotAddr.IP.Equal(expectedIP) {
t.Error("bad ip", ip, "!=", expectedIP) t.Error("bad IP", gotAddr.IP, "!=", expectedIP)
} }
if port != expectedPort { if gotAddr.Port != expectedPort {
t.Error("bad port", port, "!=", expectedPort) t.Error("bad Port", gotAddr.Port, "!=", expectedPort)
} }
} }
func BenchmarkMessage_AddErrorCode(b *testing.B) { func BenchmarkErrorCode_AddTo(b *testing.B) {
m := New() m := New()
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.AddErrorCode(404, "Not found") CodeStaleNonce.AddTo(m)
m.Reset()
}
}
func BenchmarkErrorCodeAttribute_AddTo(b *testing.B) {
m := New()
b.ReportAllocs()
a := &ErrorCodeAttribute{
Code: 404,
Reason: []byte("not found!"),
}
for i := 0; i < b.N; i++ {
a.AddTo(m)
m.Reset() m.Reset()
} }
} }
@@ -202,51 +219,27 @@ func TestMessage_AddErrorCode(t *testing.T) {
t.Error(err) t.Error(err)
} }
copy(m.TransactionID[:], transactionID) copy(m.TransactionID[:], transactionID)
expectedCode := 404 expectedCode := ErrorCode(428)
expectedReason := "Not found" expectedReason := "Stale Nonce"
m.AddErrorCode(expectedCode, expectedReason) CodeStaleNonce.AddTo(m)
m.WriteHeader() m.WriteHeader()
mRes := New() mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil { if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err) t.Fatal(err)
} }
code, reason, err := mRes.GetErrorCode() errCodeAttr := new(ErrorCodeAttribute)
if err = errCodeAttr.GetFrom(mRes); err != nil {
t.Error(err)
}
code := errCodeAttr.Code
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if code != expectedCode { if code != expectedCode {
t.Error("bad code", code) t.Error("bad code", code)
} }
if string(reason) != expectedReason { if string(errCodeAttr.Reason) != expectedReason {
t.Error("bad reason", string(reason)) t.Error("bad reason", string(errCodeAttr.Reason))
}
}
func TestMessage_AddErrorCodeDefault(t *testing.T) {
m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
copy(m.TransactionID[:], transactionID)
expectedCode := 500
expectedReason := "Server Error"
m.AddErrorCodeDefault(expectedCode)
m.WriteHeader()
mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err)
}
code, reason, err := mRes.GetErrorCode()
if err != nil {
t.Error(err)
}
if code != expectedCode {
t.Error("bad code", code)
}
if string(reason) != expectedReason {
t.Error("bad reason", string(reason))
} }
} }

View File

@@ -15,6 +15,12 @@ type DecodeErr struct {
Message string Message string
} }
// IsInvalidCookie returns true if error means that magic cookie
// value is invalid.
func (e DecodeErr) IsInvalidCookie() bool {
return e.Place == DecodeErrPlace{"message", "cookie"}
}
// IsPlaceParent reports if error place parent is p. // IsPlaceParent reports if error place parent is p.
func (e DecodeErr) IsPlaceParent(p string) bool { func (e DecodeErr) IsPlaceParent(p string) bool {
return e.Place.Parent == p return e.Place.Parent == p
@@ -50,3 +56,8 @@ func newDecodeErr(parent, children, message string) *DecodeErr {
Message: message, Message: message,
} }
} }
// TODO(ar): rewrite errors to be more precise.
func newAttrDecodeErr(children, message string) *DecodeErr {
return newDecodeErr("attribute", children, message)
}

View File

@@ -12,12 +12,12 @@ func FuzzMessage(data []byte) int {
// fuzzer dont know about cookies // fuzzer dont know about cookies
binary.BigEndian.PutUint32(data[4:8], magicCookie) binary.BigEndian.PutUint32(data[4:8], magicCookie)
// trying to read data as message // trying to read data as message
if _, err := m.ReadBytes(data); err != nil { if _, err := m.Write(data); err != nil {
return 0 return 0
} }
m.WriteHeader() m.WriteHeader()
m2 := New() m2 := New()
if _, err := m2.ReadBytes(m2.Raw); err != nil { if _, err := m2.Write(m2.Raw); err != nil {
panic(err) panic(err)
} }
if m2.TransactionID != m.TransactionID { if m2.TransactionID != m.TransactionID {

491
message.go Normal file
View File

@@ -0,0 +1,491 @@
// Package stun implements Session Traversal Utilities for NAT (STUN) RFC 5389.
//
// Definitions
//
// STUN Agent: A STUN agent is an entity that implements the STUN
// protocol. The entity can be either a STUN client or a STUN
// server.
//
// STUN Client: A STUN client is an entity that sends STUN requests and
// receives STUN responses. A STUN client can also send indications.
// In this specification, the terms STUN client and client are
// synonymous.
//
// STUN Server: A STUN server is an entity that receives STUN requests
// and sends STUN responses. A STUN server can also send
// indications. In this specification, the terms STUN server and
// server are synonymous.
//
// Transport Address: The combination of an IP address and Port number
// (such as a UDP or TCP Port number).
package stun
import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"strconv"
)
const (
// magicCookie is fixed value that aids in distinguishing STUN packets
// from packets of other protocols when STUN is multiplexed with those
// other protocols on the same Port.
//
// The magic cookie field MUST contain the fixed value 0x2112A442 in
// network byte order.
//
// Defined in "STUN Message Structure", section 6.
magicCookie = 0x2112A442
attributeHeaderSize = 4
messageHeaderSize = 20
transactionIDSize = 12 // 96 bit
)
// NewTransactionID returns new random transaction ID using crypto/rand
// as source.
func NewTransactionID() (b [transactionIDSize]byte) {
_, err := rand.Read(b[:])
if err != nil {
panic(err)
}
return b
}
// IsMessage returns true if b looks like STUN message.
// Useful for multiplexing. IsMessage does not guarantee
// that decoding will be successful.
func IsMessage(b []byte) bool {
return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie
}
// New returns *Message with pre-allocated Raw.
func New() *Message {
const defaultRawCapacity = 120
return &Message{
Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
}
}
// Message represents a single STUN packet. It uses aggressive internal
// buffering to enable zero-allocation encoding and decoding,
// so there are some usage constraints:
//
// * Message and its fields is valid only until AcquireMessage call.
type Message struct {
Type MessageType
Length uint32 // len(Raw) not including header
TransactionID [transactionIDSize]byte
Attributes Attributes
Raw []byte
}
// NewTransactionID sets m.TransactionID to random value from crypto/rand
// and returns error if any.
func (m *Message) NewTransactionID() error {
_, err := rand.Read(m.TransactionID[:])
return err
}
func (m Message) String() string {
return fmt.Sprintf("%s l=%d attrs=%d id=%s",
m.Type,
m.Length,
len(m.Attributes),
base64.StdEncoding.EncodeToString(m.TransactionID[:]),
)
}
// Reset resets Message, attributes and underlying buffer length.
func (m *Message) Reset() {
m.Raw = m.Raw[:0]
m.Length = 0
m.Attributes = m.Attributes[:0]
}
// grow ensures that internal buffer will fit v more bytes and
// increases it capacity if necessary.
func (m *Message) grow(v int) {
// Not performing any optimizations here
// (e.g. preallocate len(buf) * 2 to reduce allocations)
// because they are already done by []byte implementation.
n := len(m.Raw) + v
for cap(m.Raw) < n {
m.Raw = append(m.Raw, 0)
}
m.Raw = m.Raw[:n]
}
// Add appends new attribute to message. Not goroutine-safe.
//
// Value of attribute is copied to internal buffer so
// it is safe to reuse v.
func (m *Message) Add(t AttrType, v []byte) {
// Allocating buffer for TLV (type-length-value).
// T = t, L = len(v), V = v.
// m.Raw will look like:
// [0:20] <- message header
// [20:20+m.Length] <- existing message attributes
// [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
// [first:last] <- same as previous
// [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
// T L V
allocSize := attributeHeaderSize + len(v) // len(TLV) = len(TL) + len(V)
first := messageHeaderSize + int(m.Length) // first byte number
last := first + allocSize // last byte number
m.grow(last) // growing cap(Raw) to fit TLV
m.Raw = m.Raw[:last] // now len(Raw) = last
m.Length += uint32(allocSize) // rendering length change
// Sub-slicing internal buffer to simplify encoding.
buf := m.Raw[first:last] // slice for TLV
value := buf[attributeHeaderSize:] // slice for V
attr := RawAttribute{
Type: t, // T
Length: uint16(len(v)), // L
Value: value, // V
}
// Encoding attribute TLV to allocated buffer.
bin.PutUint16(buf[0:2], attr.Type.Value()) // T
bin.PutUint16(buf[2:4], attr.Length) // L
copy(value, v) // V
// Checking that attribute value needs padding.
if attr.Length%padding != 0 {
// Performing padding.
bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
last += bytesToAdd
m.grow(last)
// setting all padding bytes to zero
// to prevent data leak from previous
// data in next bytesToAdd bytes
buf = m.Raw[last-bytesToAdd : last]
for i := range buf {
buf[i] = 0
}
m.Raw = m.Raw[:last] // increasing buffer length
m.Length += uint32(bytesToAdd) // rendering length change
}
m.Attributes = append(m.Attributes, attr)
}
// Equal returns true if Message b equals to m.
// Ignores m.Raw.
func (m *Message) Equal(b *Message) bool {
if m.Type != b.Type {
return false
}
if m.TransactionID != b.TransactionID {
return false
}
if m.Length != b.Length {
return false
}
for _, a := range m.Attributes {
aB, ok := b.Attributes.Get(a.Type)
if !ok {
return false
}
if !aB.Equal(a) {
return false
}
}
return true
}
// WriteLength writes m.Length to m.Raw. Call is valid only if len(m.Raw) >= 4.
func (m *Message) WriteLength() {
_ = m.Raw[4] // early bounds check to guarantee safety of writes below
bin.PutUint16(m.Raw[2:4], uint16(m.Length))
}
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
func (m *Message) WriteHeader() {
if len(m.Raw) < messageHeaderSize {
// Making WriteHeader call valid even when m.Raw
// is nil or len(m.Raw) is less than needed for header.
m.grow(messageHeaderSize)
}
_ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below
bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type
bin.PutUint16(m.Raw[2:4], uint16(len(m.Raw)-messageHeaderSize)) // size of payload
bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie
copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
}
// WriteAttributes encodes all m.Attributes to m.
func (m *Message) WriteAttributes() {
for _, a := range m.Attributes {
m.Add(a.Type, a.Value)
}
}
// Encode resets m.Raw and calls WriteHeader and WriteAttributes.
func (m *Message) Encode() {
m.Raw = m.Raw[:0]
m.WriteHeader()
m.WriteAttributes()
}
// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
// call result.
func (m *Message) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.Raw)
return int64(n), err
}
// Append appends m.Raw to v. Useful to call after encoding message.
func (m *Message) Append(v []byte) []byte {
return append(v, m.Raw...)
}
// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
// Decodes it and return error if any. If m.Raw is too small, will return
// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
//
// Can return *DecodeErr while decoding too.
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
tBuf := m.Raw[:cap(m.Raw)]
var (
n int
err error
)
if n, err = r.Read(tBuf); err != nil {
return int64(n), err
}
m.Raw = tBuf[:n]
return int64(n), m.Decode()
}
const (
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
// m.Raw to read header.
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header"
)
// Decode decodes m.Raw into m.
func (m *Message) Decode() error {
// decoding message header
buf := m.Raw
if len(buf) < messageHeaderSize {
return ErrUnexpectedHeaderEOF
}
var (
t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes
size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes
cookie = binary.BigEndian.Uint32(buf[4:8])
fullSize = messageHeaderSize + size
)
if cookie != magicCookie {
msg := fmt.Sprintf(
"%x is invalid magic cookie (should be %x)",
cookie, magicCookie,
)
return newDecodeErr("message", "cookie", msg)
}
if len(buf) < fullSize {
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected message size)",
len(buf), fullSize,
)
return newAttrDecodeErr("message", msg)
}
// saving header data
m.Type.ReadValue(t)
m.Length = uint32(size)
copy(m.TransactionID[:], buf[8:messageHeaderSize])
var (
offset = 0
b = buf[messageHeaderSize:fullSize]
)
for offset < size {
// checking that we have enough bytes to read header
if len(b) < attributeHeaderSize {
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected header size)",
len(b), attributeHeaderSize,
)
return newAttrDecodeErr("header", msg)
}
var (
a = RawAttribute{
Type: AttrType(bin.Uint16(b[0:2])), // first 2 bytes
Length: bin.Uint16(b[2:4]), // second 2 bytes
}
aL = int(a.Length) // attribute length
aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
)
b = b[attributeHeaderSize:] // slicing again to simplify value read
offset += attributeHeaderSize
if len(b) < aBuffL { // checking size
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected value size)",
len(b), aBuffL,
)
return newAttrDecodeErr("value", msg)
}
a.Value = b[:aL]
offset += aBuffL
b = b[aBuffL:]
m.Attributes = append(m.Attributes, a)
}
return nil
}
// Write decodes message and return error if any.
//
// Any error is unrecoverable, but message could be partially decoded.
func (m *Message) Write(tBuf []byte) (int, error) {
m.Raw = append(m.Raw[:0], tBuf...)
return len(tBuf), m.Decode()
}
// MaxPacketSize is maximum size of UDP packet that is processable in
// this package for STUN message.
const MaxPacketSize = 2048
// MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
type MessageClass byte
// Possible values for message class in STUN Message Type.
const (
ClassRequest MessageClass = 0x00 // 0b00
ClassIndication MessageClass = 0x01 // 0b01
ClassSuccessResponse MessageClass = 0x02 // 0b10
ClassErrorResponse MessageClass = 0x03 // 0b11
)
func (c MessageClass) String() string {
switch c {
case ClassRequest:
return "request"
case ClassIndication:
return "indication"
case ClassSuccessResponse:
return "success response"
case ClassErrorResponse:
return "error response"
default:
panic("unknown message class")
}
}
// Method is uint16 representation of 12-bit STUN method.
type Method uint16
// Possible methods for STUN Message.
const (
MethodBinding Method = 0x001
MethodAllocate Method = 0x003
MethodRefresh Method = 0x004
MethodSend Method = 0x006
MethodData Method = 0x007
MethodCreatePermission Method = 0x008
MethodChannelBind Method = 0x009
)
func (m Method) String() string {
switch m {
case MethodBinding:
return "binding"
case MethodAllocate:
return "allocate"
case MethodRefresh:
return "refresh"
case MethodSend:
return "send"
case MethodData:
return "data"
case MethodCreatePermission:
return "create permission"
case MethodChannelBind:
return "channel bind"
default:
return fmt.Sprintf("0x%s", strconv.FormatUint(uint64(m), 16))
}
}
// MessageType is STUN Message Type Field.
type MessageType struct {
Class MessageClass
Method Method
}
const (
methodABits = 0xf // 0b0000000000001111
methodBBits = 0x70 // 0b0000000001110000
methodDBits = 0xf80 // 0b0000111110000000
methodBShift = 1
methodDShift = 2
firstBit = 0x1
secondBit = 0x2
c0Bit = firstBit
c1Bit = secondBit
classC0Shift = 4
classC1Shift = 7
)
// Value returns bit representation of messageType.
func (t MessageType) Value() uint16 {
// 0 1
// 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
// |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
// |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
// Figure 3: Format of STUN Message Type Field
// Warning: Abandon all hope ye who enter here.
// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
m := uint16(t.Method)
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
m = a + (b << methodBShift) + (d << methodDShift)
// C0 is zero bit of C, C1 is fist bit.
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
// Ct = C0 << 4 + C1 << 8.
// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
// We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
// (see figure 3).
c := uint16(t.Class)
c0 := (c & c0Bit) << classC0Shift
c1 := (c & c1Bit) << classC1Shift
class := c0 + c1
return m + class
}
// ReadValue decodes uint16 into MessageType.
func (t *MessageType) ReadValue(v uint16) {
// Decoding class.
// We are taking first bit from v >> 4 and second from v >> 7.
c0 := (v >> classC0Shift) & c0Bit
c1 := (v >> classC1Shift) & c1Bit
class := c0 + c1
t.Class = MessageClass(class)
// Decoding method.
a := v & methodABits // A(M0-M3)
b := (v >> methodBShift) & methodBBits // B(M4-M6)
d := (v >> methodDShift) & methodDBits // D(M7-M11)
m := a + b + d
t.Method = Method(m)
}
func (t MessageType) String() string {
return fmt.Sprintf("%s %s", t.Method, t.Class)
}

13
padding.go Normal file
View File

@@ -0,0 +1,13 @@
package stun
const (
padding = 4
)
func nearestPaddedValueLength(l int) int {
n := padding * (l / padding)
if n < l {
n += padding
}
return n
}

View File

@@ -6,7 +6,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"runtime" "sync/atomic"
"time" "time"
"github.com/ernado/stun" "github.com/ernado/stun"
@@ -17,13 +17,14 @@ var (
fmt.Sprintf("127.0.0.1:%d", stun.DefaultPort), fmt.Sprintf("127.0.0.1:%d", stun.DefaultPort),
"addr to attack", "addr to attack",
) )
readWorkers = flag.Int("read-workers", 1, "concurrent read workers")
writeWorkers = flag.Int("write-workers", 1, "concurrent write workers")
count int64 count int64
) )
func main() { func main() {
flag.Parse() flag.Parse()
runtime.GOMAXPROCS(2)
log.SetFlags(log.Lshortfile) log.SetFlags(log.Lshortfile)
go func() { go func() {
log.Println(http.ListenAndServe("localhost:6060", nil)) log.Println(http.ListenAndServe("localhost:6060", nil))
@@ -43,8 +44,9 @@ func main() {
}, },
TransactionID: stun.NewTransactionID(), TransactionID: stun.NewTransactionID(),
} }
m.AddRaw(stun.AttrSoftware, []byte("stun benchmark")) m.Add(stun.AttrSoftware, []byte("stun benchmark"))
m.Encode() m.Encode()
for i := 0; i < *readWorkers; i++ {
go func() { go func() {
mRec := stun.New() mRec := stun.New()
mRec.Raw = make([]byte, 1024) mRec.Raw = make([]byte, 1024)
@@ -58,15 +60,26 @@ func main() {
// if err := mRec.Decode(); err != nil { // if err := mRec.Decode(); err != nil {
// log.Fatalln("Decode:", err) // log.Fatalln("Decode:", err)
// } // }
count++ atomic.AddInt64(&count, 1)
if count%10000 == 0 { if c := atomic.LoadInt64(&count); c%10000 == 0 {
fmt.Printf("%d\n", count) fmt.Printf("%d\n", c)
elapsed := time.Since(start) elapsed := time.Since(start)
fmt.Println(float64(count)/elapsed.Seconds(), "per second") fmt.Println(float64(c)/elapsed.Seconds(), "per second")
} }
// mRec.Reset() // mRec.Reset()
} }
}() }()
}
for i := 1; i < *writeWorkers; i++ {
go func() {
for {
_, err := c.Write(m.Raw)
if err != nil {
log.Fatalln("write:", err)
}
}
}()
}
for { for {
_, err := c.Write(m.Raw) _, err := c.Write(m.Raw)
if err != nil { if err != nil {

View File

@@ -33,7 +33,7 @@ func TestClient_Do(t *testing.T) {
m := stun.New() m := stun.New()
m.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} m.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
m.TransactionID = stun.NewTransactionID() m.TransactionID = stun.NewTransactionID()
m.Add(stun.NewSoftware("cydev/stun alpha")) stun.NewSoftware("cydev/stun alpha").AddTo(m)
m.WriteHeader() m.WriteHeader()
request := Request{ request := Request{
Target: "stun.l.google.com:19302", Target: "stun.l.google.com:19302",
@@ -43,11 +43,11 @@ func TestClient_Do(t *testing.T) {
if r.Message.TransactionID != m.TransactionID { if r.Message.TransactionID != m.TransactionID {
t.Error("transaction id messmatch") t.Error("transaction id messmatch")
} }
ip, port, err := r.Message.GetXORMappedAddress() addr := new(stun.XORMappedAddress)
if err != nil { if err := addr.GetFrom(m); err != nil {
t.Error(err) t.Error(err)
} }
log.Println("got", ip, port) log.Println("got", addr)
return nil return nil
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -190,7 +190,7 @@ func discover(c *cli.Context) error {
Class: stun.ClassRequest, Class: stun.ClassRequest,
}, },
} }
m.AddRaw(stun.AttrSoftware, software.Raw) m.Add(stun.AttrSoftware, software.Raw)
m.WriteHeader() m.WriteHeader()
request := Request{ request := Request{
@@ -200,15 +200,13 @@ func discover(c *cli.Context) error {
return DefaultClient.Do(request, func(r Response) error { return DefaultClient.Do(request, func(r Response) error {
var ( var (
ip net.IP
port int
err error err error
) )
ip, port, err = r.Message.GetXORMappedAddress() addr := new(stun.XORMappedAddress)
if err != nil { if err = addr.GetFrom(r.Message); err != nil {
return errors.Wrap(err, "failed to get ip") return errors.Wrap(err, "failed to get ip")
} }
fmt.Println(ip, port) fmt.Println(addr)
return nil return nil
}) })
} }

View File

@@ -23,7 +23,7 @@ func main() {
log.Fatalln("Unable to decode bas64 value:", err) log.Fatalln("Unable to decode bas64 value:", err)
} }
m := stun.New() m := stun.New()
if _, err = m.ReadBytes(data); err != nil { if _, err = m.Write(data); err != nil {
log.Fatalln("Unable to decode message:", err) log.Fatalln("Unable to decode message:", err)
} }
fmt.Println(m) fmt.Println(m)

540
stun.go
View File

@@ -16,542 +16,14 @@
// indications. In this specification, the terms STUN server and // indications. In this specification, the terms STUN server and
// server are synonymous. // server are synonymous.
// //
// Transport Address: The combination of an IP address and port number // Transport Address: The combination of an IP address and Port number
// (such as a UDP or TCP port number). // (such as a UDP or TCP Port number).
package stun package stun
import ( import "encoding/binary"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
)
var ( // bin is shorthand to binary.BigEndian.
// bin is shorthand to binary.BigEndian. var bin = binary.BigEndian
bin = binary.BigEndian
)
// DefaultPort is IANA assigned port for "stun" protocol. // DefaultPort is IANA assigned Port for "stun" protocol.
const DefaultPort = 3478 const DefaultPort = 3478
const (
// magicCookie is fixed value that aids in distinguishing STUN packets
// from packets of other protocols when STUN is multiplexed with those
// other protocols on the same port.
//
// The magic cookie field MUST contain the fixed value 0x2112A442 in
// network byte order.
//
// Defined in "STUN Message Structure", section 6.
magicCookie = 0x2112A442
)
const transactionIDSize = 12 // 96 bit
// Message represents a single STUN packet. It uses aggressive internal
// buffering to enable zero-allocation encoding and decoding,
// so there are some usage constraints:
//
// * Message and its fields is valid only until AcquireMessage call.
type Message struct {
Type MessageType
Length uint32 // len(Raw) not including header
TransactionID [transactionIDSize]byte
Attributes Attributes
Raw []byte
}
// CopyTo copies all m to c.
func (m Message) CopyTo(c *Message) {
c.Type = m.Type
c.Length = m.Length
copy(c.TransactionID[:], m.TransactionID[:])
buf := m.Raw[:int(m.Length)+messageHeaderSize]
c.Raw = c.Raw[:0]
c.Raw = append(c.Raw, buf...)
buf = c.Raw[messageHeaderSize:]
for _, a := range m.Attributes {
buf = buf[attributeHeaderSize:]
c.Attributes = append(c.Attributes, RawAttribute{
Length: a.Length,
Type: a.Type,
Value: buf[:int(a.Length)],
})
buf = buf[int(a.Length):]
}
}
func (m Message) String() string {
return fmt.Sprintf("%s (l=%d,%d/%d) attr[%d] id[%s]",
m.Type,
m.Length,
len(m.Raw),
cap(m.Raw),
len(m.Attributes),
base64.StdEncoding.EncodeToString(m.TransactionID[:]),
)
}
// NewTransactionID returns new random transaction ID using crypto/rand
// as source.
func NewTransactionID() (b [transactionIDSize]byte) {
_, err := rand.Read(b[:])
if err != nil {
panic(err)
}
return b
}
// defaults for pool.
const (
defaultMessageBufferCapacity = 120
)
// New returns *Message with allocated Raw.
func New() *Message {
return &Message{
Raw: make([]byte, messageHeaderSize, defaultMessageBufferCapacity),
}
}
// Reset resets Message length, attributes and underlying buffer, as well as
// setting readOnly flag to false.
func (m *Message) Reset() {
m.Raw = m.Raw[:0]
m.Length = 0
m.Attributes = m.Attributes[:0]
}
// grow ensures that internal buffer will fit v more bytes and
// increases it capacity if necessary.
func (m *Message) grow(v int) {
// Not performing any optimizations here
// (e.g. preallocate len(buf) * 2 to reduce allocations)
// because they are already done by []byte implementation.
n := len(m.Raw) + v
for cap(m.Raw) < n {
m.Raw = append(m.Raw, 0)
}
m.Raw = m.Raw[:n]
}
// Add adds AttrEncoder to message, calling Encode method.
func (m *Message) Add(a AttrEncoder) error {
var (
err error
t AttrType
)
initial := len(m.Raw)
for i := 0; i < attributeHeaderSize; i++ {
m.Raw = append(m.Raw, 0)
}
start := len(m.Raw)
t, m.Raw, err = a.Encode(m.Raw, m)
if err != nil {
m.Raw = m.Raw[:initial]
return err
}
m.AddRaw(t, m.Raw[start:])
return nil
}
// AddRaw appends new attribute to message. Not goroutine-safe.
//
// Value of attribute is copied to internal buffer so
// it is safe to reuse v.
func (m *Message) AddRaw(t AttrType, v []byte) {
// CPU: suboptimal;
// allocating memory for TLV (type-length-value), where
// type-length is attribute header.
// m.buf.B[0:20] is reserved by header
// internal buffer will look like:
// [0:20] <- message header
// [20:20+m.Length] <- added message attributes
// [20+m.Length:20+m.Length+len(v) + 4] <- allocated slice for new TLV
// [first:last] <- same as previous
// [0 1|2 3|4 4 + len(v)]
// T L V
allocSize := attributeHeaderSize + len(v) // len(TLV)
first := messageHeaderSize + int(m.Length) // first byte number
last := first + allocSize // last byte number
m.grow(last) // growing cap(b) to fit TLV
m.Raw = m.Raw[:last] // now len(b) = last
m.Length += uint32(allocSize) // rendering length change
// subslicing internal buffer to simplify encoding
buf := m.Raw[first:last] // slice for TLV
value := buf[attributeHeaderSize:] // slice for value
attr := RawAttribute{
Type: t,
Value: value,
Length: uint16(len(v)),
}
// encoding attribute TLV to internal buffer
bin.PutUint16(buf[0:2], attr.Type.Value()) // type
bin.PutUint16(buf[2:4], attr.Length) // length
copy(value, v) // value
if attr.Length%padding != 0 {
// performing padding
bytesToAdd := nearestLength(len(v)) - len(v)
last += bytesToAdd
m.grow(last)
// setting all padding bytes to zero
// to prevent data leak from previous
// data in next bytesToAdd bytes
buf = m.Raw[last-bytesToAdd : last]
for i := range buf {
buf[i] = 0
}
m.Raw = m.Raw[:last] // increasing buffer length
m.Length += uint32(bytesToAdd) // rendering length change
}
m.Attributes = append(m.Attributes, attr)
}
const (
padding = 4
)
func nearestLength(l int) int {
n := padding * (l / padding)
if n < l {
n += padding
}
return n
}
// Equal returns true if Message b equals to m.
func (m *Message) Equal(b *Message) bool {
if m.Type != b.Type {
return false
}
if m.TransactionID != b.TransactionID {
return false
}
if m.Length != b.Length {
return false
}
for _, a := range m.Attributes {
aB, ok := b.Attributes.Get(a.Type)
if !ok {
return false
}
if !aB.Equal(a) {
return false
}
}
return true
}
// WriteLength writes m.Length to m.Raw.
func (m *Message) WriteLength() { bin.PutUint16(m.Raw[2:4], uint16(m.Length)) }
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
func (m *Message) WriteHeader() {
// encoding header
if len(m.Raw) < messageHeaderSize {
m.grow(messageHeaderSize)
}
bin.PutUint16(m.Raw[0:2], m.Type.Value())
bin.PutUint32(m.Raw[4:8], magicCookie)
copy(m.Raw[8:messageHeaderSize], m.TransactionID[:])
// attributes are already encoded
// writing length as size, in bytes, not including the 20-byte STUN header.
bin.PutUint16(m.Raw[2:4], uint16(len(m.Raw)-messageHeaderSize))
}
// WriteAttributes encodes all m.Attributes to m.
func (m *Message) WriteAttributes() {
for _, a := range m.Attributes {
m.AddRaw(a.Type, a.Value)
}
}
// Encode writes m into Raw.
func (m *Message) Encode() {
m.Raw = m.Raw[:0]
m.WriteHeader()
m.WriteAttributes()
}
// WriteTo implements WriterTo.
func (m *Message) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.Raw)
return int64(n), err
}
// WriteToConn writes a packet with message to addr, using c.
// Deprecated; non-idiomatic.
func (m *Message) WriteToConn(c net.PacketConn, addr net.Addr) (n int, err error) {
return c.WriteTo(m.Raw, addr)
}
// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
// Decodes it and return error if any. If m.Raw is too small, will return
// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
//
// Can return *DecodeErr while decoding.
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
tBuf := m.Raw[:cap(m.Raw)]
var (
n int
err error
)
if n, err = r.Read(tBuf); err != nil {
return int64(n), err
}
m.Raw = tBuf[:n]
return int64(n), m.Decode()
}
func newAttrDecodeErr(children, message string) *DecodeErr {
return newDecodeErr("attribute", children, message)
}
// IsMessage returns true if b looks like STUN message.
// Useful for multiplexing. IsMessage does not guarantee
// that decoding will be successful.
func IsMessage(b []byte) bool {
// b should be at least messageHeaderSize bytes and
// contain correct magicCookie.
return len(b) >= messageHeaderSize &&
binary.BigEndian.Uint32(b[4:8]) == magicCookie
}
const (
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
// m.Raw to read header.
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header"
)
// Decode decodes m.Raw into m.
func (m *Message) Decode() error {
// decoding message header
buf := m.Raw
if len(buf) < messageHeaderSize {
return ErrUnexpectedHeaderEOF
}
var (
t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes
size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes
cookie = binary.BigEndian.Uint32(buf[4:8])
fullSize = messageHeaderSize + size
)
if cookie != magicCookie {
msg := fmt.Sprintf(
"%x is invalid magic cookie (should be %x)",
cookie, magicCookie,
)
return newDecodeErr("message", "cookie", msg)
}
if len(buf) < fullSize {
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected message size)",
len(buf), fullSize,
)
return newAttrDecodeErr("message", msg)
}
// saving header data
m.Type.ReadValue(t)
m.Length = uint32(size)
copy(m.TransactionID[:], buf[8:messageHeaderSize])
var (
offset = 0
b = buf[messageHeaderSize:fullSize]
)
for offset < size {
// checking that we have enough bytes to read header
if len(b) < attributeHeaderSize {
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected header size)",
len(b), attributeHeaderSize,
)
return newAttrDecodeErr("header", msg)
}
var (
a = RawAttribute{
Type: AttrType(bin.Uint16(b[0:2])), // first 2 bytes
Length: bin.Uint16(b[2:4]), // second 2 bytes
}
aL = int(a.Length) // attribute length
aBuffL = nearestLength(aL) // expected buffer length (with padding)
)
b = b[attributeHeaderSize:] // slicing again to simplify value read
offset += attributeHeaderSize
if len(b) < aBuffL { // checking size
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected value size)",
len(b), aBuffL,
)
return newAttrDecodeErr("value", msg)
}
a.Value = b[:aL]
offset += aBuffL
b = b[aBuffL:]
m.Attributes = append(m.Attributes, a)
}
return nil
}
// ReadBytes decodes message and return error if any.
//
// Any error is unrecoverable, but message could be partially decoded.
// Deprecated: use m.Decode.
func (m *Message) ReadBytes(tBuf []byte) (int, error) {
m.Raw = append(m.Raw[:0], tBuf...)
return len(tBuf), m.Decode()
}
const (
attributeHeaderSize = 4
messageHeaderSize = 20
)
// MaxPacketSize is maximum size of UDP packet that is processable in
// this package for STUN message.
const MaxPacketSize = 2048
// MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
type MessageClass byte
// Possible values for message class in STUN Message Type.
const (
ClassRequest MessageClass = 0x00 // 0b00
ClassIndication MessageClass = 0x01 // 0b01
ClassSuccessResponse MessageClass = 0x02 // 0b10
ClassErrorResponse MessageClass = 0x03 // 0b11
)
func (c MessageClass) String() string {
switch c {
case ClassRequest:
return "request"
case ClassIndication:
return "indication"
case ClassSuccessResponse:
return "success response"
case ClassErrorResponse:
return "error response"
default:
panic("unknown message class")
}
}
// Method is uint16 representation of 12-bit STUN method.
type Method uint16
// Possible methods for STUN Message.
const (
MethodBinding Method = 0x001
MethodAllocate Method = 0x003
MethodRefresh Method = 0x004
MethodSend Method = 0x006
MethodData Method = 0x007
MethodCreatePermission Method = 0x008
MethodChannelBind Method = 0x009
)
func (m Method) String() string {
switch m {
case MethodBinding:
return "binding"
case MethodAllocate:
return "allocate"
case MethodRefresh:
return "refresh"
case MethodSend:
return "send"
case MethodData:
return "data"
case MethodCreatePermission:
return "create permission"
case MethodChannelBind:
return "channel bind"
default:
return fmt.Sprintf("0x%s", strconv.FormatUint(uint64(m), 16))
}
}
// MessageType is STUN Message Type Field.
type MessageType struct {
Class MessageClass
Method Method
}
const (
methodABits = 0xf // 0b0000000000001111
methodBBits = 0x70 // 0b0000000001110000
methodDBits = 0xf80 // 0b0000111110000000
methodBShift = 1
methodDShift = 2
firstBit = 0x1
secondBit = 0x2
c0Bit = firstBit
c1Bit = secondBit
classC0Shift = 4
classC1Shift = 7
)
// Value returns bit representation of messageType.
func (t MessageType) Value() uint16 {
// 0 1
// 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
// |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
// |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
// Figure 3: Format of STUN Message Type Field
// warning: Abandon all hope ye who enter here.
// splitting M into A(M0-M3), B(M4-M6), D(M7-M11)
m := uint16(t.Method)
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
// shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit)
m = a + (b << methodBShift) + (d << methodDShift)
// C0 is zero bit of C, C1 is fist bit.
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
// Ct = C0 << 4 + C1 << 8.
// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
// We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
// (see figure 3).
c := uint16(t.Class)
c0 := (c & c0Bit) << classC0Shift
c1 := (c & c1Bit) << classC1Shift
class := c0 + c1
return m + class
}
// ReadValue decodes uint16 into MessageType.
func (t *MessageType) ReadValue(v uint16) {
// decoding class
// we are taking first bit from v >> 4 and second from v >> 7.
c0 := (v >> classC0Shift) & c0Bit
c1 := (v >> classC1Shift) & c1Bit
class := c0 + c1
t.Class = MessageClass(class)
// decoding method
a := v & methodABits // A(M0-M3)
b := (v >> methodBShift) & methodBBits // B(M4-M6)
d := (v >> methodDShift) & methodDBits // D(M7-M11)
m := a + b + d
t.Method = Method(m)
}
func (t MessageType) String() string {
return fmt.Sprintf("%s %s", t.Method, t.Class)
}

View File

@@ -9,14 +9,13 @@ import (
"hash/crc64" "hash/crc64"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"net"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -33,24 +32,11 @@ func (m *Message) reader() *bytes.Reader {
return bytes.NewReader(m.Raw) return bytes.NewReader(m.Raw)
} }
func TestMessageCopy(t *testing.T) {
m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.AddRaw(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
mCopy := New()
m.CopyTo(mCopy)
if !mCopy.Equal(m) {
t.Error(mCopy, "!=", m)
}
}
func TestMessageBuffer(t *testing.T) { func TestMessageBuffer(t *testing.T) {
m := New() m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID() m.TransactionID = NewTransactionID()
m.AddRaw(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader() m.WriteHeader()
mDecoded := New() mDecoded := New()
if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil { if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
@@ -69,7 +55,7 @@ func BenchmarkMessage_Write(b *testing.B) {
transactionID := NewTransactionID() transactionID := NewTransactionID()
m := New() m := New()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.AddRaw(AttrErrorCode, attributeValue) m.Add(AttrErrorCode, attributeValue)
m.TransactionID = transactionID m.TransactionID = transactionID
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.WriteHeader() m.WriteHeader()
@@ -133,6 +119,25 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
} }
} }
func TestMessage_WriteTo(t *testing.T) {
m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
buf := new(bytes.Buffer)
if _, err := m.WriteTo(buf); err != nil {
t.Fatal(err)
}
mDecoded := New()
if _, err := mDecoded.ReadFrom(buf); err != nil {
t.Error(err)
}
if !mDecoded.Equal(m) {
t.Error(mDecoded, "!", m)
}
}
func TestMessage_Cookie(t *testing.T) { func TestMessage_Cookie(t *testing.T) {
buf := make([]byte, 20) buf := make([]byte, 20)
mDecoded := New() mDecoded := New()
@@ -156,11 +161,11 @@ func TestMessage_BadLength(t *testing.T) {
Length: 4, Length: 4,
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
} }
m.AddRaw(0x1, []byte{1, 2}) m.Add(0x1, []byte{1, 2})
m.WriteHeader() m.WriteHeader()
m.Raw[20+3] = 10 // set attr length = 10 m.Raw[20+3] = 10 // set attr length = 10
mDecoded := New() mDecoded := New()
if _, err := mDecoded.ReadBytes(m.Raw); err == nil { if _, err := mDecoded.Write(m.Raw); err == nil {
t.Error("should error") t.Error("should error")
} }
} }
@@ -295,7 +300,7 @@ func BenchmarkMessage_ReadBytes(b *testing.B) {
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(m.Raw)))
mRec := New() mRec := New()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if _, err := mRec.ReadBytes(m.Raw); err != nil { if _, err := mRec.Write(m.Raw); err != nil {
b.Fatal(err) b.Fatal(err)
} }
mRec.Reset() mRec.Reset()
@@ -439,9 +444,7 @@ func TestMessage_String(t *testing.T) {
func TestIsMessage(t *testing.T) { func TestIsMessage(t *testing.T) {
m := New() m := New()
m.Add(&Software{ NewSoftware("software").AddTo(m)
Raw: []byte("test"),
})
m.WriteHeader() m.WriteHeader()
var tt = [...]struct { var tt = [...]struct {
@@ -468,7 +471,7 @@ func BenchmarkIsMessage(b *testing.B) {
m := New() m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID() m.TransactionID = NewTransactionID()
m.AddSoftware("cydev/stun test") NewSoftware("cydev/stun test").AddTo(m)
m.WriteHeader() m.WriteHeader()
b.SetBytes(int64(messageHeaderSize)) b.SetBytes(int64(messageHeaderSize))
@@ -502,13 +505,13 @@ func loadData(tb testing.TB, name string) []byte {
func TestExampleChrome(t *testing.T) { func TestExampleChrome(t *testing.T) {
buf := loadData(t, "ex1_chrome.stun") buf := loadData(t, "ex1_chrome.stun")
m := New() m := New()
_, err := m.ReadBytes(buf) _, err := m.Write(buf)
if err != nil { if err != nil {
t.Errorf("Failed to parse ex1_chrome: %s", err) t.Errorf("Failed to parse ex1_chrome: %s", err)
} }
} }
func TestNearestLen(t *testing.T) { func TestPadding(t *testing.T) {
tt := []struct { tt := []struct {
in, out int in, out int
}{ }{
@@ -525,7 +528,7 @@ func TestNearestLen(t *testing.T) {
{40, 40}, // 10 {40, 40}, // 10
} }
for i, c := range tt { for i, c := range tt {
if got := nearestLength(c.in); got != c.out { if got := nearestPaddedValueLength(c.in); got != c.out {
t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)", t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)",
i, c.in, got, c.out, i, c.in, got, c.out,
) )
@@ -562,7 +565,7 @@ func TestMessageFromBrowsers(t *testing.T) {
if b != crc64.Checksum(data, crcTable) { if b != crc64.Checksum(data, crcTable) {
t.Error("crc64 check failed for ", line[1]) t.Error("crc64 check failed for ", line[1])
} }
if _, err = m.ReadBytes(data); err != nil { if _, err = m.Write(data); err != nil {
t.Error("failed to decode ", line[1], " as message: ", err) t.Error("failed to decode ", line[1], " as message: ", err)
} }
m.Reset() m.Reset()
@@ -583,20 +586,29 @@ func BenchmarkNewTransactionID(b *testing.B) {
} }
} }
func BenchmarkMessage_NewTransactionID(b *testing.B) {
b.ReportAllocs()
m := new(Message)
for i := 0; i < b.N; i++ {
if err := m.NewTransactionID(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkMessageFull(b *testing.B) { func BenchmarkMessageFull(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
m := new(Message) m := new(Message)
s := NewSoftware("software") s := NewSoftware("software")
addr := &XORMappedAddress{ addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
fingerprint := new(Fingerprint)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.Add(addr) addAttr(b, m, addr)
m.Add(s) addAttr(b, m, s)
m.WriteAttributes() m.WriteAttributes()
m.WriteHeader() m.WriteHeader()
fingerprint.AddTo(m) Fingerprint.AddTo(m)
m.WriteHeader() m.WriteHeader()
m.Reset() m.Reset()
} }
@@ -607,19 +619,26 @@ func BenchmarkMessageFullHardcore(b *testing.B) {
m := new(Message) m := new(Message)
s := NewSoftware("software") s := NewSoftware("software")
addr := &XORMappedAddress{ addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
t, v, _ := addr.Encode(nil, m)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.AddRaw(AttrSoftware, s.Raw) if err := addr.AddTo(m); err != nil {
m.AddRaw(t, v) b.Fatal(err)
m.WriteAttributes() }
if err := s.AddTo(m); err != nil {
b.Fatal(err)
}
m.WriteHeader()
m.Reset() m.Reset()
} }
} }
func addAttr(t testing.TB, m *Message, a AttrEncoder) { type attributeEncoder interface {
if err := m.Add(a); err != nil { AddTo(m *Message) error
}
func addAttr(t testing.TB, m *Message, a attributeEncoder) {
if err := a.AddTo(m); err != nil {
t.Error(err) t.Error(err)
} }
} }
@@ -629,14 +648,13 @@ func BenchmarkFingerprint_AddTo(b *testing.B) {
m := new(Message) m := new(Message)
s := NewSoftware("software") s := NewSoftware("software")
addr := &XORMappedAddress{ addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
addAttr(b, m, addr) addAttr(b, m, addr)
addAttr(b, m, s) addAttr(b, m, s)
fingerprint := new(Fingerprint)
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(m.Raw)))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
fingerprint.AddTo(m) Fingerprint.AddTo(m)
m.WriteLength() m.WriteLength()
m.Length -= attributeHeaderSize + fingerprintSize m.Length -= attributeHeaderSize + fingerprintSize
m.Raw = m.Raw[:m.Length+messageHeaderSize] m.Raw = m.Raw[:m.Length+messageHeaderSize]
@@ -648,12 +666,10 @@ func TestFingerprint_Check(t *testing.T) {
m := new(Message) m := new(Message)
addAttr(t, m, NewSoftware("software")) addAttr(t, m, NewSoftware("software"))
m.WriteHeader() m.WriteHeader()
fInit := new(Fingerprint) Fingerprint.AddTo(m)
fInit.AddTo(m)
m.WriteHeader() m.WriteHeader()
f := new(Fingerprint) if err := Fingerprint.Check(m); err != nil {
if err := f.Check(m); err != nil { t.Error(err)
t.Errorf("decoded: %d, encoded: %d, err: %s", f.Value, fInit.Value, err)
} }
} }
@@ -662,17 +678,16 @@ func BenchmarkFingerprint_Check(b *testing.B) {
m := new(Message) m := new(Message)
s := NewSoftware("software") s := NewSoftware("software")
addr := &XORMappedAddress{ addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
addAttr(b, m, addr) addAttr(b, m, addr)
addAttr(b, m, s) addAttr(b, m, s)
fingerprint := new(Fingerprint)
m.WriteHeader() m.WriteHeader()
fingerprint.AddTo(m) Fingerprint.AddTo(m)
m.WriteHeader() m.WriteHeader()
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(m.Raw)))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := fingerprint.Check(m); err != nil { if err := Fingerprint.Check(m); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }

View File

@@ -5,9 +5,11 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"runtime" "net/http"
"strings" "strings"
_ "net/http/pprof"
"github.com/ernado/stun" "github.com/ernado/stun"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -16,6 +18,8 @@ import (
var ( var (
network = flag.String("net", "udp", "network to listen") network = flag.String("net", "udp", "network to listen")
address = flag.String("addr", "0.0.0.0:3478", "address to listen") address = flag.String("addr", "0.0.0.0:3478", "address to listen")
workers = flag.Int("workers", 1, "workers to start")
profile = flag.Bool("profile", false, "profile")
) )
// Server is RFC 5389 basic server implementation. // Server is RFC 5389 basic server implementation.
@@ -80,7 +84,7 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
if !stun.IsMessage(b) { if !stun.IsMessage(b) {
return errNotSTUNMessage return errNotSTUNMessage
} }
if _, err := req.ReadBytes(b); err != nil { if _, err := req.Write(b); err != nil {
return errors.Wrap(err, "failed to read message") return errors.Wrap(err, "failed to read message")
} }
res.TransactionID = req.TransactionID res.TransactionID = req.TransactionID
@@ -99,8 +103,12 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
default: default:
panic(fmt.Sprintf("unknown addr: %v", addr)) panic(fmt.Sprintf("unknown addr: %v", addr))
} }
res.AddXORMappedAddress(ip, port) xma := &stun.XORMappedAddress{
res.AddRaw(stun.AttrSoftware, software.Raw) IP: ip,
Port: port,
}
xma.AddTo(res)
software.AddTo(res)
res.WriteHeader() res.WriteHeader()
return nil return nil
} }
@@ -116,8 +124,8 @@ func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error {
return nil return nil
} }
// s.logger().Printf("read %d bytes from %s", n, addr) // s.logger().Printf("read %d bytes from %s", n, addr)
if _, err = req.ReadBytes(buf[:n]); err != nil { if _, err = req.Write(buf[:n]); err != nil {
s.logger().Printf("ReadBytes: %v", err) s.logger().Printf("Write: %v", err)
return err return err
} }
if err = basicProcess(addr, buf[:n], req, res); err != nil { if err = basicProcess(addr, buf[:n], req, res); err != nil {
@@ -156,8 +164,10 @@ func ListenUDPAndServe(serverNet, laddr string) error {
if err != nil { if err != nil {
return err return err
} }
s := &Server{} for i := 1; i < *workers; i++ {
return s.Serve(c) go new(Server).Serve(c)
}
return new(Server).Serve(c)
} }
func normalize(address string) string { func normalize(address string) string {
@@ -172,7 +182,11 @@ func normalize(address string) string {
func main() { func main() {
flag.Parse() flag.Parse()
runtime.GOMAXPROCS(1) if *profile {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
}
switch *network { switch *network {
case "udp": case "udp":
normalized := normalize(*address) normalized := normalize(*address)

View File

@@ -17,7 +17,7 @@ func BenchmarkBasicProcess(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
m.TransactionID = stun.NewTransactionID() m.TransactionID = stun.NewTransactionID()
m.Add(stun.NewSoftware("some software")) stun.NewSoftware("some software").AddTo(m)
m.WriteHeader() m.WriteHeader()
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(m.Raw)))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

2
xor.go
View File

@@ -34,7 +34,7 @@ func fastXORBytes(dst, a, b []byte) int {
} }
} }
for i := (n - n%wordSize); i < n; i++ { for i := n - n%wordSize; i < n; i++ {
dst[i] = a[i] ^ b[i] dst[i] = a[i] ^ b[i]
} }