all: refactor attributes

This commit is contained in:
Aleksandr Razumov
2017-02-11 05:23:49 +03:00
parent 3ec51bf758
commit fba9f6b02d
21 changed files with 1048 additions and 1107 deletions

View File

@@ -1,4 +1,4 @@
Copyright (c) 2016 Aleksandr Razumov, Cydev. All Rigths Reserved.
Copyright (c) 2016-2017 Aleksandr Razumov, Cydev. All Rigths Reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are

View File

@@ -2,39 +2,88 @@ package stun
// ErrorCodeAttribute represents ERROR-CODE attribute.
type ErrorCodeAttribute struct {
Code int
Code ErrorCode
Reason []byte
}
// constants for ERROR-CODE encoding.
const (
errorCodeReasonStart = 4
errorCodeClassByte = 2
errorCodeNumberByte = 3
errorCodeReasonMaxB = 763
errorCodeModulo = 100
)
// AddTo adds ERROR-CODE to m.
func (c *ErrorCodeAttribute) AddTo(m *Message) error {
value := make([]byte,
errorCodeReasonStart, errorCodeReasonMaxB,
)
number := byte(c.Code % errorCodeModulo) // error code modulo 100
class := byte(c.Code / errorCodeModulo) // hundred digit
value[errorCodeClassByte] = class
value[errorCodeNumberByte] = number
value = append(value, c.Reason...)
m.Add(AttrErrorCode, value)
return nil
}
// GetFrom decodes ERROR-CODE from m.
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
v, err := m.Get(AttrErrorCode)
if err != nil {
return err
}
var (
class = uint16(v[errorCodeClassByte])
number = uint16(v[errorCodeNumberByte])
code = int(class*errorCodeModulo + number)
reason = v[errorCodeReasonStart:]
)
c.Code = ErrorCode(code)
c.Reason = reason
return nil
}
// ErrorCode is code for ERROR-CODE attribute.
type ErrorCode int
// ErrNoDefaultReason means that default reason for provided error code
// is not defined in RFC.
const ErrNoDefaultReason Error = "No default reason for ErrorCode"
// AddTo adds ERROR-CODE with default reason to m. If there
// is no default reason, returns ErrNoDefaultReason.
func (c ErrorCode) AddTo(m *Message) error {
reason := errorReasons[c]
if reason == nil {
return ErrNoDefaultReason
}
a := &ErrorCodeAttribute{
Code: c,
Reason: reason,
}
return a.AddTo(m)
}
// Possible error codes.
const (
CodeTryAlternate = 300
CodeBadRequest = 400
CodeUnauthorised = 401
CodeUnknownAttribute = 420
CodeStaleNonce = 428
CodeRoleConflict = 478
CodeServerError = 500
CodeTryAlternate ErrorCode = 300
CodeBadRequest ErrorCode = 400
CodeUnauthorised ErrorCode = 401
CodeUnknownAttribute ErrorCode = 420
CodeStaleNonce ErrorCode = 428
CodeRoleConflict ErrorCode = 478
CodeServerError ErrorCode = 500
)
var errorReasons = map[int]string{
CodeTryAlternate: "Try Alternate",
CodeBadRequest: "Bad Request",
CodeUnauthorised: "Unauthorised",
CodeUnknownAttribute: "Unknown Attribute",
CodeStaleNonce: "Stale Nonce",
CodeServerError: "Server Error",
CodeRoleConflict: "Role Conflict",
}
// Reason returns recommended reason string.
func (c ErrorCode) Reason() string {
reason, ok := errorReasons[int(c)]
if !ok {
return "Unknown Error"
}
return reason
var errorReasons = map[ErrorCode][]byte{
CodeTryAlternate: []byte("Try Alternate"),
CodeBadRequest: []byte("Bad Request"),
CodeUnauthorised: []byte("Unauthorised"),
CodeUnknownAttribute: []byte("Unknown Attribute"),
CodeStaleNonce: []byte("Stale Nonce"),
CodeServerError: []byte("Server Error"),
CodeRoleConflict: []byte("Role Conflict"),
}

View File

@@ -1,26 +0,0 @@
package stun
import "testing"
func TestErrorCode_Reason(t *testing.T) {
codes := [...]ErrorCode{
CodeTryAlternate,
CodeBadRequest,
CodeUnauthorised,
CodeUnknownAttribute,
CodeStaleNonce,
CodeRoleConflict,
CodeServerError,
}
for _, code := range codes {
if code.Reason() == "Unknown Error" {
t.Error(code, "should not be unknown")
}
if len(code.Reason()) == 0 {
t.Error(code, "should not be blank")
}
}
if ErrorCode(999).Reason() != "Unknown Error" {
t.Error("999 error should be Unknown")
}
}

68
attribute_fingerprint.go Normal file
View File

@@ -0,0 +1,68 @@
package stun
import (
"fmt"
"hash/crc32"
)
// FingerprintAttr represent FINGERPRINT attribute.
type FingerprintAttr struct{}
// CRCMismatch represents CRC check error.
type CRCMismatch struct {
Expected uint32
Actual uint32
}
func (m CRCMismatch) Error() string {
return fmt.Sprintf("CRC mismatch: %x (expected) != %x (actual)",
m.Expected,
m.Actual,
)
}
// Fingerprint is shorthand for FingerprintAttr.
var Fingerprint = &FingerprintAttr{}
const (
fingerprintXORValue uint32 = 0x5354554e
fingerprintSize = 4 // 32 bit
)
// FingerprintValue returns CRC32 of m XOR-ed by 0x5354554e.
func FingerprintValue(b []byte) uint32 {
return crc32.ChecksumIEEE(b) ^ fingerprintXORValue // XOR
}
// AddTo adds fingerprint to message.
func (FingerprintAttr) AddTo(m *Message) error {
l := m.Length
// length in header should include size of fingerprint attribute
m.Length += fingerprintSize + attributeHeaderSize // increasing length
m.WriteLength() // writing Length to Raw
b := make([]byte, fingerprintSize)
v := FingerprintValue(m.Raw)
bin.PutUint32(b, v)
m.Length = l
m.Add(AttrFingerprint, b)
return nil
}
// Check reads fingerprint value from m and checks it, returning error if any.
// Can return *DecodeErr, ErrAttributeNotFound and *CRCMismatch.
func (FingerprintAttr) Check(m *Message) error {
v, err := m.Get(AttrFingerprint)
if err != nil {
return err
}
if len(v) != fingerprintSize {
return newDecodeErr("message", "fingerprint", "bad length")
}
val := bin.Uint32(v)
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
expected := FingerprintValue(m.Raw[:attrStart])
if expected != val {
return &CRCMismatch{Expected: expected, Actual: val}
}
return nil
}

31
attribute_software.go Normal file
View File

@@ -0,0 +1,31 @@
package stun
// Software is SOFTWARE attribute.
type Software struct {
Raw []byte
}
func (s *Software) String() string {
return string(s.Raw)
}
// NewSoftware returns *Software from string.
func NewSoftware(software string) *Software {
return &Software{Raw: []byte(software)}
}
// AddTo adds Software attribute to m.
func (s *Software) AddTo(m *Message) error {
m.Add(AttrSoftware, m.Raw)
return nil
}
// GetFrom decodes Software from m.
func (s *Software) GetFrom(m *Message) error {
v, err := m.Get(AttrSoftware)
if err != nil {
return err
}
s.Raw = v
return nil
}

118
attribute_xoraddr.go Normal file
View File

@@ -0,0 +1,118 @@
package stun
import (
"fmt"
"net"
)
const (
familyIPv4 byte = 0x01
familyIPv6 byte = 0x02
)
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
type XORMappedAddress struct {
IP net.IP
Port int
}
func (a XORMappedAddress) String() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
// Is p all zeros?
func isZeros(p net.IP) bool {
for i := 0; i < len(p); i++ {
if p[i] != 0 {
return false
}
}
return true
}
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
const ErrBadIPLength Error = "invalid length if IP value"
// AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength
// if len(a.IP) is invalid.
func (a *XORMappedAddress) AddTo(m *Message) error {
var (
family = familyIPv4
ip = a.IP
)
if len(a.IP) == net.IPv6len {
// Optimized for performance. See net.IP.To4 method.
if isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
ip = ip[12:16] // like in ip.To4()
} else {
family = familyIPv6
}
} else if len(ip) != net.IPv4len {
return ErrBadIPLength
}
value := make([]byte, 32+128)
value[0] = 0 // first 8 bits are zeroes
xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:])
bin.PutUint32(xorValue[0:4], magicCookie)
bin.PutUint16(value[0:2], uint16(family))
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16))
xorBytes(value[4:4+len(ip)], ip, xorValue)
m.Add(AttrXORMappedAddress, value[:4+len(ip)])
return nil
}
// GetFrom decodes XOR-MAPPED-ADDRESS attribute in message and returns
// error if any. While decoding, a.IP is reused if possible and can be
// rendered to invalid state (e.g. if a.IP was set to IPv6 and then
// IPv4 value were decoded into it), be careful.
//
// Example:
//
// expectedIP := net.ParseIP("213.141.156.236")
// expectedIP.String() // 213.141.156.236, 16 bytes, first 12 of them are zeroes
// expectedPort := 21254
// addr := &XORMappedAddress{
// IP: expectedIP,
// Port: expectedPort,
// }
// // addr were added to message that is decoded as newMessage
// // ...
//
// addr.GetFrom(newMessage)
// addr.IP.String() // 213.141.156.236, net.IPv4Len
// expectedIP.String() // d58d:9cec::ffff:d58d:9cec, 16 bytes, first 4 are IPv4
// // now we have len(expectedIP) = 16 and len(addr.IP) = 4.
func (a *XORMappedAddress) GetFrom(m *Message) error {
v, err := m.Get(AttrXORMappedAddress)
if err != nil {
return err
}
family := byte(bin.Uint16(v[0:2]))
if family != familyIPv6 && family != familyIPv4 {
return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
)
}
ipLen := net.IPv4len
if family == familyIPv6 {
ipLen = net.IPv6len
}
// Ensuring len(a.IP) == ipLen and reusing a.IP.
if len(a.IP) < ipLen {
a.IP = a.IP[:cap(a.IP)]
for len(a.IP) < ipLen {
a.IP = append(a.IP, 0)
}
}
a.IP = a.IP[:ipLen]
for i := range a.IP {
a.IP[i] = 0
}
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 4+transactionIDSize)
bin.PutUint32(xorValue[0:4], magicCookie)
copy(xorValue[4:], m.TransactionID[:])
xorBytes(a.IP, v[4:], xorValue)
return nil
}

View File

@@ -1,23 +1,10 @@
package stun
import (
"encoding/binary"
"fmt"
"hash/crc32"
"net"
"strconv"
)
// AttrWriter wraps AddRaw method.
type AttrWriter interface {
AddRaw(t AttrType, v []byte)
}
// AttrEncoder wraps Encode method.
type AttrEncoder interface {
Encode(b []byte, m *Message) (AttrType, []byte, error)
}
// Attributes is list of message attributes.
type Attributes []RawAttribute
@@ -77,7 +64,7 @@ const (
AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN
)
// Attributes from An Origin RawAttribute for the STUN Protocol.
// Attributes from An Origin Attribute for the STUN Protocol.
const (
AttrOrigin AttrType = 0x802F
)
@@ -118,7 +105,7 @@ var attrNames = map[AttrType]string{
func (t AttrType) String() string {
s, ok := attrNames[t]
if !ok {
// just return hex representation of unknown attribute type
// Just return hex representation of unknown attribute type.
return "0x" + strconv.FormatUint(uint64(t), 16)
}
return s
@@ -131,21 +118,17 @@ func (t AttrType) String() string {
// don't understand, but cannot successfully process a message if it
// contains comprehension-required attributes that are not
// understood.
//
// TODO(ar): Decide to use pointer or non-pointer RawAttribute.
type RawAttribute struct {
Type AttrType
Length uint16 // ignored while encoding
Value []byte
}
// Encode implements AttrEncoder.
func (a *RawAttribute) Encode(m *Message) ([]byte, error) {
return m.Raw, nil
}
// Decode implements AttrDecoder.
func (a *RawAttribute) Decode(v []byte, m *Message) error {
a.Value = v
a.Length = uint16(len(v))
// AddTo adds RawAttribute to m.
func (a *RawAttribute) AddTo(m *Message) error {
m.Add(a.Type, m.Raw)
return nil
}
@@ -172,319 +155,17 @@ func (a RawAttribute) String() string {
return fmt.Sprintf("%s: %x", a.Type, a.Value)
}
// getAttrValue returns byte slice that represents attribute value,
// if there is no value attribute with shuck type,
// ErrAttributeNotFound means that attribute with provided attribute
// type does not exist in message.
const ErrAttributeNotFound Error = "Attribute not found"
// Get returns byte slice that represents attribute value,
// if there is no attribute with such type,
// ErrAttributeNotFound is returned.
func (m *Message) getAttrValue(t AttrType) ([]byte, error) {
func (m *Message) Get(t AttrType) ([]byte, error) {
v, ok := m.Attributes.Get(t)
if !ok {
return nil, ErrAttributeNotFound
}
return v.Value, nil
}
// AddSoftware adds SOFTWARE attribute with value from string.
// Deprecated: use AddRaw.
func (m *Message) AddSoftware(software string) {
m.AddRaw(AttrSoftware, []byte(software))
}
// Set sets the value of attribute if it presents.
func (m *Message) Set(a AttrEncoder) error {
var (
v []byte
err error
t AttrType
)
t, v, err = a.Encode(v, m)
if err != nil {
return err
}
buf, err := m.getAttrValue(t)
if err != nil {
return err
}
if len(v) != len(buf) {
return ErrBadSetLength
}
copy(buf, v)
return nil
}
// GetSoftwareBytes returns SOFTWARE attribute value in byte slice.
// If not found, returns nil.
func (m *Message) GetSoftwareBytes() []byte {
v, ok := m.Attributes.Get(AttrSoftware)
if !ok {
return nil
}
return v.Value
}
// GetSoftware returns SOFTWARE attribute value in string.
// If not found, returns blank string.
// Deprecated.
func (m *Message) GetSoftware() string { return string(m.GetSoftwareBytes()) }
// Address family values.
const (
FamilyIPv4 byte = 0x01
FamilyIPv6 byte = 0x02
)
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
type XORMappedAddress struct {
ip net.IP
port int
}
// Encode implements AttrEncoder.
func (a *XORMappedAddress) Encode(buf []byte, m *Message) (AttrType, []byte, error) {
// X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order.
family := FamilyIPv6
ip := a.ip
port := a.port
if ipV4 := ip.To4(); ipV4 != nil {
ip = ipV4
family = FamilyIPv4
}
value := make([]byte, 32+128)
value[0] = 0 // first 8 bits are zeroes
xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:])
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
port ^= magicCookie >> 16
binary.BigEndian.PutUint16(value[0:2], uint16(family))
binary.BigEndian.PutUint16(value[2:4], uint16(port))
xorBytes(value[4:4+len(ip)], ip, xorValue)
buf = append(buf, value...)
return AttrXORMappedAddress, buf, nil
}
// Decode implements AttrDecoder.
// TODO(ar): fix signature.
func (a *XORMappedAddress) Decode(v []byte, m *Message) error {
// X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order.
v, err := m.getAttrValue(AttrXORMappedAddress)
if err != nil {
return err
}
family := byte(binary.BigEndian.Uint16(v[0:2]))
if family != FamilyIPv6 && family != FamilyIPv4 {
return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
)
}
ipLen := net.IPv4len
if family == FamilyIPv6 {
ipLen = net.IPv6len
}
ip := net.IP(m.allocBuffer(ipLen))
a.port = int(binary.BigEndian.Uint16(v[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 128)
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
copy(xorValue[4:], m.TransactionID[:])
xorBytes(ip, v[4:], xorValue)
a.ip = ip
return nil
}
// AddXORMappedAddress adds XOR MAPPED ADDRESS attribute to message.
// Deprecated: use AddRaw.
func (m *Message) AddXORMappedAddress(ip net.IP, port int) {
// X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order.
family := FamilyIPv6
if ipV4 := ip.To4(); ipV4 != nil {
ip = ipV4
family = FamilyIPv4
}
value := make([]byte, 32+128)
value[0] = 0 // first 8 bits are zeroes
xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:])
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
port ^= magicCookie >> 16
binary.BigEndian.PutUint16(value[0:2], uint16(family))
binary.BigEndian.PutUint16(value[2:4], uint16(port))
xorBytes(value[4:4+len(ip)], ip, xorValue)
m.AddRaw(AttrXORMappedAddress, value[:4+len(ip)])
}
func (m *Message) allocBuffer(size int) []byte {
capacity := len(m.Raw) + size
m.grow(capacity)
m.Raw = m.Raw[:capacity]
return m.Raw[len(m.Raw)-size:]
}
// GetXORMappedAddress returns ip, port from attribute and error if any.
// Value for ip is valid until Message is released or underlying buffer is
// corrupted. Returns *DecodeError or ErrAttributeNotFound.
// Deprecated: use GetRaw.
func (m *Message) GetXORMappedAddress() (net.IP, int, error) {
// X-Port is computed by taking the mapped port in host byte order,
// XORing it with the most significant 16 bits of the magic cookie, and
// then the converting the result to network byte order.
v, err := m.getAttrValue(AttrXORMappedAddress)
if err != nil {
return nil, 0, err
}
family := byte(binary.BigEndian.Uint16(v[0:2]))
if family != FamilyIPv6 && family != FamilyIPv4 {
return nil, 0, newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
)
}
ipLen := net.IPv4len
if family == FamilyIPv6 {
ipLen = net.IPv6len
}
ip := net.IP(m.allocBuffer(ipLen))
port := int(binary.BigEndian.Uint16(v[2:4])) ^ (magicCookie >> 16)
xorValue := make([]byte, 128)
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
copy(xorValue[4:], m.TransactionID[:])
xorBytes(ip, v[4:], xorValue)
return ip, port, nil
}
// constants for ERROR-CODE encoding.
const (
errorCodeReasonStart = 4
errorCodeClassByte = 2
errorCodeNumberByte = 3
errorCodeReasonMaxB = 763
errorCodeModulo = 100
)
// AddErrorCode adds ERROR-CODE attribute to message.
//
// The reason phrase MUST be a UTF-8 [RFC 3629] encoded
// sequence of less than 128 characters (which can be as long as 763
// bytes).
// Deprecated: use AddRaw.
func (m *Message) AddErrorCode(code int, reason string) {
value := make([]byte,
errorCodeReasonStart, errorCodeReasonMaxB+errorCodeReasonStart,
)
number := byte(code % errorCodeModulo) // error code modulo 100
class := byte(code / errorCodeModulo) // hundred digit
value[errorCodeClassByte] = class
value[errorCodeNumberByte] = number
value = append(value, reason...)
m.AddRaw(AttrErrorCode, value)
}
// AddErrorCodeDefault is wrapper for AddErrorCode that uses recommended
// reason string from RFC. If error code is unknown, reason will be "Unknown
// Error".
// Deprecated: use AddRaw.
func (m *Message) AddErrorCodeDefault(code int) {
m.AddErrorCode(code, ErrorCode(code).Reason())
}
// GetErrorCode returns ERROR-CODE code, reason and decode error if any.
// Deprecated: use GetRaw.
func (m *Message) GetErrorCode() (int, []byte, error) {
v, err := m.getAttrValue(AttrErrorCode)
if err != nil {
return 0, nil, err
}
var (
class = uint16(v[errorCodeClassByte])
number = uint16(v[errorCodeNumberByte])
code = int(class*errorCodeModulo + number)
reason = v[errorCodeReasonStart:]
)
return code, reason, nil
}
const (
// ErrAttributeNotFound means that there is no such attribute.
ErrAttributeNotFound Error = "Attribute not found"
// ErrBadSetLength means that previous attribute value length differs from
// new value.
ErrBadSetLength Error = "Previous attribute length is different"
)
// Software is SOFTWARE attribute.
type Software struct {
Raw []byte
}
func (s Software) String() string {
return string(s.Raw)
}
// NewSoftware returns *Software from string.
func NewSoftware(software string) *Software {
return &Software{Raw: []byte(software)}
}
// Encode implements AttrEncoder.
func (s *Software) Encode(b []byte, m *Message) (AttrType, []byte, error) {
return AttrSoftware, append(b, s.Raw...), nil
}
// Decode implements AttrDecoder.
func (s *Software) Decode(v []byte, m *Message) error {
s.Raw = v
return nil
}
const (
fingerprintXORValue uint32 = 0x5354554e
)
// Fingerprint represents FINGERPRINT attribute.
type Fingerprint struct {
Value uint32 // CRC-32 of message XOR-ed with 0x5354554e
}
const (
fingerprintSize = 4 // 32 bit
)
// AddTo adds fingerprint to message.
func (f *Fingerprint) AddTo(m *Message) error {
l := m.Length
// length in header should include size of fingerprint attribute
m.Length += fingerprintSize + attributeHeaderSize // increasing length
m.WriteLength() // writing Length to Raw
b := make([]byte, fingerprintSize)
f.Value = crc32.ChecksumIEEE(m.Raw) ^ fingerprintXORValue // XOR
bin.PutUint32(b, f.Value)
m.Length = l
m.AddRaw(AttrFingerprint, b)
return nil
}
// Check reads fingerprint value from m and checks it, returning error if any.
// Can return *DecodeErr, ErrAttributeNotFound, ErrCRCMissmatch.
func (f *Fingerprint) Check(m *Message) error {
v, err := m.getAttrValue(AttrFingerprint)
if err != nil {
return err
}
if len(v) != fingerprintSize {
return newDecodeErr("message", "fingerprint", "bad length")
}
f.Value = bin.Uint32(v)
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
expected := crc32.ChecksumIEEE(m.Raw[:attrStart]) ^ fingerprintXORValue
if expected != f.Value {
return ErrCRCMissmatch
}
return nil
}
// ErrCRCMissmatch means that calculated fingerprint attribute differs from
// expected one.
const ErrCRCMissmatch Error = "CRC32 missmatch: bad fingerprint value"

View File

@@ -10,21 +10,24 @@ import (
"testing"
)
func TestMessage_AddSoftware(t *testing.T) {
func TestSoftware_GetFrom(t *testing.T) {
m := New()
v := "Client v0.0.1"
m.AddRaw(AttrSoftware, []byte(v))
m.Add(AttrSoftware, []byte(v))
m.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
software := new(Software)
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
vRead := m.GetSoftware()
if vRead != v {
t.Errorf("Expected %s, got %s.", v, vRead)
if err := software.GetFrom(m); err != nil {
t.Fatal(err)
}
if software.String() != v {
t.Errorf("Expected %q, got %q.", v, software)
}
sAttr, ok := m.Attributes.Get(AttrSoftware)
@@ -37,29 +40,18 @@ func TestMessage_AddSoftware(t *testing.T) {
}
}
func TestMessage_GetSoftware(t *testing.T) {
m := New()
v := m.GetSoftware()
if v != "" {
t.Errorf("%s should be blank.", v)
}
vByte := m.GetSoftwareBytes()
if vByte != nil {
t.Errorf("%s should be nil.", vByte)
}
}
func BenchmarkMessage_AddXORMappedAddress(b *testing.B) {
func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
m := New()
b.ReportAllocs()
ip := net.ParseIP("192.168.1.32")
for i := 0; i < b.N; i++ {
m.AddXORMappedAddress(ip, 3654)
addr := &XORMappedAddress{IP: ip, Port: 3654}
addr.AddTo(m)
m.Reset()
}
}
func BenchmarkMessage_GetXORMappedAddress(b *testing.B) {
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
@@ -70,10 +62,13 @@ func BenchmarkMessage_GetXORMappedAddress(b *testing.B) {
if err != nil {
b.Error(err)
}
m.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
m.AddRaw(AttrXORMappedAddress, addrValue)
m.GetXORMappedAddress()
m.Reset()
if err := addr.GetFrom(m); err != nil {
b.Fatal(err)
}
}
}
@@ -88,16 +83,16 @@ func TestMessage_GetXORMappedAddress(t *testing.T) {
if err != nil {
t.Error(err)
}
m.AddRaw(AttrXORMappedAddress, addrValue)
ip, port, err := m.GetXORMappedAddress()
if err != nil {
m.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress)
if err = addr.GetFrom(m); err != nil {
t.Error(err)
}
if !ip.Equal(net.ParseIP("213.141.156.236")) {
t.Error("bad ip", ip, "!=", "213.141.156.236")
if !addr.IP.Equal(net.ParseIP("213.141.156.236")) {
t.Error("bad IP", addr.IP, "!=", "213.141.156.236")
}
if port != 48583 {
t.Error("bad port", port, "!=", 48583)
if addr.Port != 48583 {
t.Error("bad Port", addr.Port, "!=", 48583)
}
}
@@ -110,13 +105,15 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
copy(m.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254
addr := new(XORMappedAddress)
_, _, err = m.GetXORMappedAddress()
if err == nil {
if err = addr.GetFrom(m); err == nil {
t.Fatal(err, "should be nil")
}
m.AddXORMappedAddress(expectedIP, expectedPort)
addr.IP = expectedIP
addr.Port = expectedPort
addr.AddTo(m)
m.WriteHeader()
mRes := New()
@@ -124,8 +121,7 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
t.Fatal(err)
}
_, _, err = m.GetXORMappedAddress()
if err == nil {
if err = addr.GetFrom(m); err == nil {
t.Fatal(err, "should not be nil")
}
}
@@ -139,22 +135,26 @@ func TestMessage_AddXORMappedAddress(t *testing.T) {
copy(m.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254
m.AddXORMappedAddress(expectedIP, expectedPort)
addr := &XORMappedAddress{
IP: net.ParseIP("213.141.156.236"),
Port: expectedPort,
}
if err = addr.AddTo(m); err != nil {
t.Fatal(err)
}
m.WriteHeader()
mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil {
if _, err = mRes.Write(m.Raw); err != nil {
t.Fatal(err)
}
ip, port, err := m.GetXORMappedAddress()
if err != nil {
if err = addr.GetFrom(mRes); err != nil {
t.Fatal(err)
}
if !ip.Equal(expectedIP) {
t.Error("bad ip", ip, "!=", expectedIP)
if !addr.IP.Equal(expectedIP) {
t.Errorf("%s (got) != %s (expected)", addr.IP, expectedIP)
}
if port != expectedPort {
t.Error("bad port", port, "!=", expectedPort)
if addr.Port != expectedPort {
t.Error("bad Port", addr.Port, "!=", expectedPort)
}
}
@@ -167,30 +167,47 @@ func TestMessage_AddXORMappedAddressV6(t *testing.T) {
copy(m.TransactionID[:], transactionID)
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
expectedPort := 21254
m.AddXORMappedAddress(expectedIP, expectedPort)
addr := &XORMappedAddress{
IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"),
Port: 21254,
}
addr.AddTo(m)
m.WriteHeader()
mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err)
}
ip, port, err := m.GetXORMappedAddress()
if err != nil {
gotAddr := new(XORMappedAddress)
if err = gotAddr.GetFrom(m); err != nil {
t.Fatal(err)
}
if !ip.Equal(expectedIP) {
t.Error("bad ip", ip, "!=", expectedIP)
if !gotAddr.IP.Equal(expectedIP) {
t.Error("bad IP", gotAddr.IP, "!=", expectedIP)
}
if port != expectedPort {
t.Error("bad port", port, "!=", expectedPort)
if gotAddr.Port != expectedPort {
t.Error("bad Port", gotAddr.Port, "!=", expectedPort)
}
}
func BenchmarkMessage_AddErrorCode(b *testing.B) {
func BenchmarkErrorCode_AddTo(b *testing.B) {
m := New()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
m.AddErrorCode(404, "Not found")
CodeStaleNonce.AddTo(m)
m.Reset()
}
}
func BenchmarkErrorCodeAttribute_AddTo(b *testing.B) {
m := New()
b.ReportAllocs()
a := &ErrorCodeAttribute{
Code: 404,
Reason: []byte("not found!"),
}
for i := 0; i < b.N; i++ {
a.AddTo(m)
m.Reset()
}
}
@@ -202,51 +219,27 @@ func TestMessage_AddErrorCode(t *testing.T) {
t.Error(err)
}
copy(m.TransactionID[:], transactionID)
expectedCode := 404
expectedReason := "Not found"
m.AddErrorCode(expectedCode, expectedReason)
expectedCode := ErrorCode(428)
expectedReason := "Stale Nonce"
CodeStaleNonce.AddTo(m)
m.WriteHeader()
mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err)
}
code, reason, err := mRes.GetErrorCode()
errCodeAttr := new(ErrorCodeAttribute)
if err = errCodeAttr.GetFrom(mRes); err != nil {
t.Error(err)
}
code := errCodeAttr.Code
if err != nil {
t.Error(err)
}
if code != expectedCode {
t.Error("bad code", code)
}
if string(reason) != expectedReason {
t.Error("bad reason", string(reason))
}
}
func TestMessage_AddErrorCodeDefault(t *testing.T) {
m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
copy(m.TransactionID[:], transactionID)
expectedCode := 500
expectedReason := "Server Error"
m.AddErrorCodeDefault(expectedCode)
m.WriteHeader()
mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err)
}
code, reason, err := mRes.GetErrorCode()
if err != nil {
t.Error(err)
}
if code != expectedCode {
t.Error("bad code", code)
}
if string(reason) != expectedReason {
t.Error("bad reason", string(reason))
if string(errCodeAttr.Reason) != expectedReason {
t.Error("bad reason", string(errCodeAttr.Reason))
}
}

View File

@@ -15,6 +15,12 @@ type DecodeErr struct {
Message string
}
// IsInvalidCookie returns true if error means that magic cookie
// value is invalid.
func (e DecodeErr) IsInvalidCookie() bool {
return e.Place == DecodeErrPlace{"message", "cookie"}
}
// IsPlaceParent reports if error place parent is p.
func (e DecodeErr) IsPlaceParent(p string) bool {
return e.Place.Parent == p
@@ -50,3 +56,8 @@ func newDecodeErr(parent, children, message string) *DecodeErr {
Message: message,
}
}
// TODO(ar): rewrite errors to be more precise.
func newAttrDecodeErr(children, message string) *DecodeErr {
return newDecodeErr("attribute", children, message)
}

View File

@@ -12,12 +12,12 @@ func FuzzMessage(data []byte) int {
// fuzzer dont know about cookies
binary.BigEndian.PutUint32(data[4:8], magicCookie)
// trying to read data as message
if _, err := m.ReadBytes(data); err != nil {
if _, err := m.Write(data); err != nil {
return 0
}
m.WriteHeader()
m2 := New()
if _, err := m2.ReadBytes(m2.Raw); err != nil {
if _, err := m2.Write(m2.Raw); err != nil {
panic(err)
}
if m2.TransactionID != m.TransactionID {

491
message.go Normal file
View File

@@ -0,0 +1,491 @@
// Package stun implements Session Traversal Utilities for NAT (STUN) RFC 5389.
//
// Definitions
//
// STUN Agent: A STUN agent is an entity that implements the STUN
// protocol. The entity can be either a STUN client or a STUN
// server.
//
// STUN Client: A STUN client is an entity that sends STUN requests and
// receives STUN responses. A STUN client can also send indications.
// In this specification, the terms STUN client and client are
// synonymous.
//
// STUN Server: A STUN server is an entity that receives STUN requests
// and sends STUN responses. A STUN server can also send
// indications. In this specification, the terms STUN server and
// server are synonymous.
//
// Transport Address: The combination of an IP address and Port number
// (such as a UDP or TCP Port number).
package stun
import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"strconv"
)
const (
// magicCookie is fixed value that aids in distinguishing STUN packets
// from packets of other protocols when STUN is multiplexed with those
// other protocols on the same Port.
//
// The magic cookie field MUST contain the fixed value 0x2112A442 in
// network byte order.
//
// Defined in "STUN Message Structure", section 6.
magicCookie = 0x2112A442
attributeHeaderSize = 4
messageHeaderSize = 20
transactionIDSize = 12 // 96 bit
)
// NewTransactionID returns new random transaction ID using crypto/rand
// as source.
func NewTransactionID() (b [transactionIDSize]byte) {
_, err := rand.Read(b[:])
if err != nil {
panic(err)
}
return b
}
// IsMessage returns true if b looks like STUN message.
// Useful for multiplexing. IsMessage does not guarantee
// that decoding will be successful.
func IsMessage(b []byte) bool {
return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie
}
// New returns *Message with pre-allocated Raw.
func New() *Message {
const defaultRawCapacity = 120
return &Message{
Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
}
}
// Message represents a single STUN packet. It uses aggressive internal
// buffering to enable zero-allocation encoding and decoding,
// so there are some usage constraints:
//
// * Message and its fields is valid only until AcquireMessage call.
type Message struct {
Type MessageType
Length uint32 // len(Raw) not including header
TransactionID [transactionIDSize]byte
Attributes Attributes
Raw []byte
}
// NewTransactionID sets m.TransactionID to random value from crypto/rand
// and returns error if any.
func (m *Message) NewTransactionID() error {
_, err := rand.Read(m.TransactionID[:])
return err
}
func (m Message) String() string {
return fmt.Sprintf("%s l=%d attrs=%d id=%s",
m.Type,
m.Length,
len(m.Attributes),
base64.StdEncoding.EncodeToString(m.TransactionID[:]),
)
}
// Reset resets Message, attributes and underlying buffer length.
func (m *Message) Reset() {
m.Raw = m.Raw[:0]
m.Length = 0
m.Attributes = m.Attributes[:0]
}
// grow ensures that internal buffer will fit v more bytes and
// increases it capacity if necessary.
func (m *Message) grow(v int) {
// Not performing any optimizations here
// (e.g. preallocate len(buf) * 2 to reduce allocations)
// because they are already done by []byte implementation.
n := len(m.Raw) + v
for cap(m.Raw) < n {
m.Raw = append(m.Raw, 0)
}
m.Raw = m.Raw[:n]
}
// Add appends new attribute to message. Not goroutine-safe.
//
// Value of attribute is copied to internal buffer so
// it is safe to reuse v.
func (m *Message) Add(t AttrType, v []byte) {
// Allocating buffer for TLV (type-length-value).
// T = t, L = len(v), V = v.
// m.Raw will look like:
// [0:20] <- message header
// [20:20+m.Length] <- existing message attributes
// [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
// [first:last] <- same as previous
// [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
// T L V
allocSize := attributeHeaderSize + len(v) // len(TLV) = len(TL) + len(V)
first := messageHeaderSize + int(m.Length) // first byte number
last := first + allocSize // last byte number
m.grow(last) // growing cap(Raw) to fit TLV
m.Raw = m.Raw[:last] // now len(Raw) = last
m.Length += uint32(allocSize) // rendering length change
// Sub-slicing internal buffer to simplify encoding.
buf := m.Raw[first:last] // slice for TLV
value := buf[attributeHeaderSize:] // slice for V
attr := RawAttribute{
Type: t, // T
Length: uint16(len(v)), // L
Value: value, // V
}
// Encoding attribute TLV to allocated buffer.
bin.PutUint16(buf[0:2], attr.Type.Value()) // T
bin.PutUint16(buf[2:4], attr.Length) // L
copy(value, v) // V
// Checking that attribute value needs padding.
if attr.Length%padding != 0 {
// Performing padding.
bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
last += bytesToAdd
m.grow(last)
// setting all padding bytes to zero
// to prevent data leak from previous
// data in next bytesToAdd bytes
buf = m.Raw[last-bytesToAdd : last]
for i := range buf {
buf[i] = 0
}
m.Raw = m.Raw[:last] // increasing buffer length
m.Length += uint32(bytesToAdd) // rendering length change
}
m.Attributes = append(m.Attributes, attr)
}
// Equal returns true if Message b equals to m.
// Ignores m.Raw.
func (m *Message) Equal(b *Message) bool {
if m.Type != b.Type {
return false
}
if m.TransactionID != b.TransactionID {
return false
}
if m.Length != b.Length {
return false
}
for _, a := range m.Attributes {
aB, ok := b.Attributes.Get(a.Type)
if !ok {
return false
}
if !aB.Equal(a) {
return false
}
}
return true
}
// WriteLength writes m.Length to m.Raw. Call is valid only if len(m.Raw) >= 4.
func (m *Message) WriteLength() {
_ = m.Raw[4] // early bounds check to guarantee safety of writes below
bin.PutUint16(m.Raw[2:4], uint16(m.Length))
}
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
func (m *Message) WriteHeader() {
if len(m.Raw) < messageHeaderSize {
// Making WriteHeader call valid even when m.Raw
// is nil or len(m.Raw) is less than needed for header.
m.grow(messageHeaderSize)
}
_ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below
bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type
bin.PutUint16(m.Raw[2:4], uint16(len(m.Raw)-messageHeaderSize)) // size of payload
bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie
copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
}
// WriteAttributes encodes all m.Attributes to m.
func (m *Message) WriteAttributes() {
for _, a := range m.Attributes {
m.Add(a.Type, a.Value)
}
}
// Encode resets m.Raw and calls WriteHeader and WriteAttributes.
func (m *Message) Encode() {
m.Raw = m.Raw[:0]
m.WriteHeader()
m.WriteAttributes()
}
// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
// call result.
func (m *Message) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.Raw)
return int64(n), err
}
// Append appends m.Raw to v. Useful to call after encoding message.
func (m *Message) Append(v []byte) []byte {
return append(v, m.Raw...)
}
// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
// Decodes it and return error if any. If m.Raw is too small, will return
// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
//
// Can return *DecodeErr while decoding too.
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
tBuf := m.Raw[:cap(m.Raw)]
var (
n int
err error
)
if n, err = r.Read(tBuf); err != nil {
return int64(n), err
}
m.Raw = tBuf[:n]
return int64(n), m.Decode()
}
const (
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
// m.Raw to read header.
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header"
)
// Decode decodes m.Raw into m.
func (m *Message) Decode() error {
// decoding message header
buf := m.Raw
if len(buf) < messageHeaderSize {
return ErrUnexpectedHeaderEOF
}
var (
t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes
size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes
cookie = binary.BigEndian.Uint32(buf[4:8])
fullSize = messageHeaderSize + size
)
if cookie != magicCookie {
msg := fmt.Sprintf(
"%x is invalid magic cookie (should be %x)",
cookie, magicCookie,
)
return newDecodeErr("message", "cookie", msg)
}
if len(buf) < fullSize {
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected message size)",
len(buf), fullSize,
)
return newAttrDecodeErr("message", msg)
}
// saving header data
m.Type.ReadValue(t)
m.Length = uint32(size)
copy(m.TransactionID[:], buf[8:messageHeaderSize])
var (
offset = 0
b = buf[messageHeaderSize:fullSize]
)
for offset < size {
// checking that we have enough bytes to read header
if len(b) < attributeHeaderSize {
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected header size)",
len(b), attributeHeaderSize,
)
return newAttrDecodeErr("header", msg)
}
var (
a = RawAttribute{
Type: AttrType(bin.Uint16(b[0:2])), // first 2 bytes
Length: bin.Uint16(b[2:4]), // second 2 bytes
}
aL = int(a.Length) // attribute length
aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
)
b = b[attributeHeaderSize:] // slicing again to simplify value read
offset += attributeHeaderSize
if len(b) < aBuffL { // checking size
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected value size)",
len(b), aBuffL,
)
return newAttrDecodeErr("value", msg)
}
a.Value = b[:aL]
offset += aBuffL
b = b[aBuffL:]
m.Attributes = append(m.Attributes, a)
}
return nil
}
// Write decodes message and return error if any.
//
// Any error is unrecoverable, but message could be partially decoded.
func (m *Message) Write(tBuf []byte) (int, error) {
m.Raw = append(m.Raw[:0], tBuf...)
return len(tBuf), m.Decode()
}
// MaxPacketSize is maximum size of UDP packet that is processable in
// this package for STUN message.
const MaxPacketSize = 2048
// MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
type MessageClass byte
// Possible values for message class in STUN Message Type.
const (
ClassRequest MessageClass = 0x00 // 0b00
ClassIndication MessageClass = 0x01 // 0b01
ClassSuccessResponse MessageClass = 0x02 // 0b10
ClassErrorResponse MessageClass = 0x03 // 0b11
)
func (c MessageClass) String() string {
switch c {
case ClassRequest:
return "request"
case ClassIndication:
return "indication"
case ClassSuccessResponse:
return "success response"
case ClassErrorResponse:
return "error response"
default:
panic("unknown message class")
}
}
// Method is uint16 representation of 12-bit STUN method.
type Method uint16
// Possible methods for STUN Message.
const (
MethodBinding Method = 0x001
MethodAllocate Method = 0x003
MethodRefresh Method = 0x004
MethodSend Method = 0x006
MethodData Method = 0x007
MethodCreatePermission Method = 0x008
MethodChannelBind Method = 0x009
)
func (m Method) String() string {
switch m {
case MethodBinding:
return "binding"
case MethodAllocate:
return "allocate"
case MethodRefresh:
return "refresh"
case MethodSend:
return "send"
case MethodData:
return "data"
case MethodCreatePermission:
return "create permission"
case MethodChannelBind:
return "channel bind"
default:
return fmt.Sprintf("0x%s", strconv.FormatUint(uint64(m), 16))
}
}
// MessageType is STUN Message Type Field.
type MessageType struct {
Class MessageClass
Method Method
}
const (
methodABits = 0xf // 0b0000000000001111
methodBBits = 0x70 // 0b0000000001110000
methodDBits = 0xf80 // 0b0000111110000000
methodBShift = 1
methodDShift = 2
firstBit = 0x1
secondBit = 0x2
c0Bit = firstBit
c1Bit = secondBit
classC0Shift = 4
classC1Shift = 7
)
// Value returns bit representation of messageType.
func (t MessageType) Value() uint16 {
// 0 1
// 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
// |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
// |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
// Figure 3: Format of STUN Message Type Field
// Warning: Abandon all hope ye who enter here.
// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
m := uint16(t.Method)
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
m = a + (b << methodBShift) + (d << methodDShift)
// C0 is zero bit of C, C1 is fist bit.
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
// Ct = C0 << 4 + C1 << 8.
// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
// We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
// (see figure 3).
c := uint16(t.Class)
c0 := (c & c0Bit) << classC0Shift
c1 := (c & c1Bit) << classC1Shift
class := c0 + c1
return m + class
}
// ReadValue decodes uint16 into MessageType.
func (t *MessageType) ReadValue(v uint16) {
// Decoding class.
// We are taking first bit from v >> 4 and second from v >> 7.
c0 := (v >> classC0Shift) & c0Bit
c1 := (v >> classC1Shift) & c1Bit
class := c0 + c1
t.Class = MessageClass(class)
// Decoding method.
a := v & methodABits // A(M0-M3)
b := (v >> methodBShift) & methodBBits // B(M4-M6)
d := (v >> methodDShift) & methodDBits // D(M7-M11)
m := a + b + d
t.Method = Method(m)
}
func (t MessageType) String() string {
return fmt.Sprintf("%s %s", t.Method, t.Class)
}

13
padding.go Normal file
View File

@@ -0,0 +1,13 @@
package stun
const (
padding = 4
)
func nearestPaddedValueLength(l int) int {
n := padding * (l / padding)
if n < l {
n += padding
}
return n
}

View File

@@ -6,7 +6,7 @@ import (
"log"
"net"
"net/http"
"runtime"
"sync/atomic"
"time"
"github.com/ernado/stun"
@@ -17,13 +17,14 @@ var (
fmt.Sprintf("127.0.0.1:%d", stun.DefaultPort),
"addr to attack",
)
readWorkers = flag.Int("read-workers", 1, "concurrent read workers")
writeWorkers = flag.Int("write-workers", 1, "concurrent write workers")
count int64
)
func main() {
flag.Parse()
runtime.GOMAXPROCS(2)
log.SetFlags(log.Lshortfile)
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
@@ -43,8 +44,9 @@ func main() {
},
TransactionID: stun.NewTransactionID(),
}
m.AddRaw(stun.AttrSoftware, []byte("stun benchmark"))
m.Add(stun.AttrSoftware, []byte("stun benchmark"))
m.Encode()
for i := 0; i < *readWorkers; i++ {
go func() {
mRec := stun.New()
mRec.Raw = make([]byte, 1024)
@@ -58,15 +60,26 @@ func main() {
// if err := mRec.Decode(); err != nil {
// log.Fatalln("Decode:", err)
// }
count++
if count%10000 == 0 {
fmt.Printf("%d\n", count)
atomic.AddInt64(&count, 1)
if c := atomic.LoadInt64(&count); c%10000 == 0 {
fmt.Printf("%d\n", c)
elapsed := time.Since(start)
fmt.Println(float64(count)/elapsed.Seconds(), "per second")
fmt.Println(float64(c)/elapsed.Seconds(), "per second")
}
// mRec.Reset()
}
}()
}
for i := 1; i < *writeWorkers; i++ {
go func() {
for {
_, err := c.Write(m.Raw)
if err != nil {
log.Fatalln("write:", err)
}
}
}()
}
for {
_, err := c.Write(m.Raw)
if err != nil {

View File

@@ -33,7 +33,7 @@ func TestClient_Do(t *testing.T) {
m := stun.New()
m.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
m.TransactionID = stun.NewTransactionID()
m.Add(stun.NewSoftware("cydev/stun alpha"))
stun.NewSoftware("cydev/stun alpha").AddTo(m)
m.WriteHeader()
request := Request{
Target: "stun.l.google.com:19302",
@@ -43,11 +43,11 @@ func TestClient_Do(t *testing.T) {
if r.Message.TransactionID != m.TransactionID {
t.Error("transaction id messmatch")
}
ip, port, err := r.Message.GetXORMappedAddress()
if err != nil {
addr := new(stun.XORMappedAddress)
if err := addr.GetFrom(m); err != nil {
t.Error(err)
}
log.Println("got", ip, port)
log.Println("got", addr)
return nil
}); err != nil {
t.Fatal(err)

View File

@@ -190,7 +190,7 @@ func discover(c *cli.Context) error {
Class: stun.ClassRequest,
},
}
m.AddRaw(stun.AttrSoftware, software.Raw)
m.Add(stun.AttrSoftware, software.Raw)
m.WriteHeader()
request := Request{
@@ -200,15 +200,13 @@ func discover(c *cli.Context) error {
return DefaultClient.Do(request, func(r Response) error {
var (
ip net.IP
port int
err error
)
ip, port, err = r.Message.GetXORMappedAddress()
if err != nil {
addr := new(stun.XORMappedAddress)
if err = addr.GetFrom(r.Message); err != nil {
return errors.Wrap(err, "failed to get ip")
}
fmt.Println(ip, port)
fmt.Println(addr)
return nil
})
}

View File

@@ -23,7 +23,7 @@ func main() {
log.Fatalln("Unable to decode bas64 value:", err)
}
m := stun.New()
if _, err = m.ReadBytes(data); err != nil {
if _, err = m.Write(data); err != nil {
log.Fatalln("Unable to decode message:", err)
}
fmt.Println(m)

540
stun.go
View File

@@ -16,542 +16,14 @@
// indications. In this specification, the terms STUN server and
// server are synonymous.
//
// Transport Address: The combination of an IP address and port number
// (such as a UDP or TCP port number).
// Transport Address: The combination of an IP address and Port number
// (such as a UDP or TCP Port number).
package stun
import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
)
import "encoding/binary"
var (
// bin is shorthand to binary.BigEndian.
bin = binary.BigEndian
)
// bin is shorthand to binary.BigEndian.
var bin = binary.BigEndian
// DefaultPort is IANA assigned port for "stun" protocol.
// DefaultPort is IANA assigned Port for "stun" protocol.
const DefaultPort = 3478
const (
// magicCookie is fixed value that aids in distinguishing STUN packets
// from packets of other protocols when STUN is multiplexed with those
// other protocols on the same port.
//
// The magic cookie field MUST contain the fixed value 0x2112A442 in
// network byte order.
//
// Defined in "STUN Message Structure", section 6.
magicCookie = 0x2112A442
)
const transactionIDSize = 12 // 96 bit
// Message represents a single STUN packet. It uses aggressive internal
// buffering to enable zero-allocation encoding and decoding,
// so there are some usage constraints:
//
// * Message and its fields is valid only until AcquireMessage call.
type Message struct {
Type MessageType
Length uint32 // len(Raw) not including header
TransactionID [transactionIDSize]byte
Attributes Attributes
Raw []byte
}
// CopyTo copies all m to c.
func (m Message) CopyTo(c *Message) {
c.Type = m.Type
c.Length = m.Length
copy(c.TransactionID[:], m.TransactionID[:])
buf := m.Raw[:int(m.Length)+messageHeaderSize]
c.Raw = c.Raw[:0]
c.Raw = append(c.Raw, buf...)
buf = c.Raw[messageHeaderSize:]
for _, a := range m.Attributes {
buf = buf[attributeHeaderSize:]
c.Attributes = append(c.Attributes, RawAttribute{
Length: a.Length,
Type: a.Type,
Value: buf[:int(a.Length)],
})
buf = buf[int(a.Length):]
}
}
func (m Message) String() string {
return fmt.Sprintf("%s (l=%d,%d/%d) attr[%d] id[%s]",
m.Type,
m.Length,
len(m.Raw),
cap(m.Raw),
len(m.Attributes),
base64.StdEncoding.EncodeToString(m.TransactionID[:]),
)
}
// NewTransactionID returns new random transaction ID using crypto/rand
// as source.
func NewTransactionID() (b [transactionIDSize]byte) {
_, err := rand.Read(b[:])
if err != nil {
panic(err)
}
return b
}
// defaults for pool.
const (
defaultMessageBufferCapacity = 120
)
// New returns *Message with allocated Raw.
func New() *Message {
return &Message{
Raw: make([]byte, messageHeaderSize, defaultMessageBufferCapacity),
}
}
// Reset resets Message length, attributes and underlying buffer, as well as
// setting readOnly flag to false.
func (m *Message) Reset() {
m.Raw = m.Raw[:0]
m.Length = 0
m.Attributes = m.Attributes[:0]
}
// grow ensures that internal buffer will fit v more bytes and
// increases it capacity if necessary.
func (m *Message) grow(v int) {
// Not performing any optimizations here
// (e.g. preallocate len(buf) * 2 to reduce allocations)
// because they are already done by []byte implementation.
n := len(m.Raw) + v
for cap(m.Raw) < n {
m.Raw = append(m.Raw, 0)
}
m.Raw = m.Raw[:n]
}
// Add adds AttrEncoder to message, calling Encode method.
func (m *Message) Add(a AttrEncoder) error {
var (
err error
t AttrType
)
initial := len(m.Raw)
for i := 0; i < attributeHeaderSize; i++ {
m.Raw = append(m.Raw, 0)
}
start := len(m.Raw)
t, m.Raw, err = a.Encode(m.Raw, m)
if err != nil {
m.Raw = m.Raw[:initial]
return err
}
m.AddRaw(t, m.Raw[start:])
return nil
}
// AddRaw appends new attribute to message. Not goroutine-safe.
//
// Value of attribute is copied to internal buffer so
// it is safe to reuse v.
func (m *Message) AddRaw(t AttrType, v []byte) {
// CPU: suboptimal;
// allocating memory for TLV (type-length-value), where
// type-length is attribute header.
// m.buf.B[0:20] is reserved by header
// internal buffer will look like:
// [0:20] <- message header
// [20:20+m.Length] <- added message attributes
// [20+m.Length:20+m.Length+len(v) + 4] <- allocated slice for new TLV
// [first:last] <- same as previous
// [0 1|2 3|4 4 + len(v)]
// T L V
allocSize := attributeHeaderSize + len(v) // len(TLV)
first := messageHeaderSize + int(m.Length) // first byte number
last := first + allocSize // last byte number
m.grow(last) // growing cap(b) to fit TLV
m.Raw = m.Raw[:last] // now len(b) = last
m.Length += uint32(allocSize) // rendering length change
// subslicing internal buffer to simplify encoding
buf := m.Raw[first:last] // slice for TLV
value := buf[attributeHeaderSize:] // slice for value
attr := RawAttribute{
Type: t,
Value: value,
Length: uint16(len(v)),
}
// encoding attribute TLV to internal buffer
bin.PutUint16(buf[0:2], attr.Type.Value()) // type
bin.PutUint16(buf[2:4], attr.Length) // length
copy(value, v) // value
if attr.Length%padding != 0 {
// performing padding
bytesToAdd := nearestLength(len(v)) - len(v)
last += bytesToAdd
m.grow(last)
// setting all padding bytes to zero
// to prevent data leak from previous
// data in next bytesToAdd bytes
buf = m.Raw[last-bytesToAdd : last]
for i := range buf {
buf[i] = 0
}
m.Raw = m.Raw[:last] // increasing buffer length
m.Length += uint32(bytesToAdd) // rendering length change
}
m.Attributes = append(m.Attributes, attr)
}
const (
padding = 4
)
func nearestLength(l int) int {
n := padding * (l / padding)
if n < l {
n += padding
}
return n
}
// Equal returns true if Message b equals to m.
func (m *Message) Equal(b *Message) bool {
if m.Type != b.Type {
return false
}
if m.TransactionID != b.TransactionID {
return false
}
if m.Length != b.Length {
return false
}
for _, a := range m.Attributes {
aB, ok := b.Attributes.Get(a.Type)
if !ok {
return false
}
if !aB.Equal(a) {
return false
}
}
return true
}
// WriteLength writes m.Length to m.Raw.
func (m *Message) WriteLength() { bin.PutUint16(m.Raw[2:4], uint16(m.Length)) }
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
func (m *Message) WriteHeader() {
// encoding header
if len(m.Raw) < messageHeaderSize {
m.grow(messageHeaderSize)
}
bin.PutUint16(m.Raw[0:2], m.Type.Value())
bin.PutUint32(m.Raw[4:8], magicCookie)
copy(m.Raw[8:messageHeaderSize], m.TransactionID[:])
// attributes are already encoded
// writing length as size, in bytes, not including the 20-byte STUN header.
bin.PutUint16(m.Raw[2:4], uint16(len(m.Raw)-messageHeaderSize))
}
// WriteAttributes encodes all m.Attributes to m.
func (m *Message) WriteAttributes() {
for _, a := range m.Attributes {
m.AddRaw(a.Type, a.Value)
}
}
// Encode writes m into Raw.
func (m *Message) Encode() {
m.Raw = m.Raw[:0]
m.WriteHeader()
m.WriteAttributes()
}
// WriteTo implements WriterTo.
func (m *Message) WriteTo(w io.Writer) (int64, error) {
n, err := w.Write(m.Raw)
return int64(n), err
}
// WriteToConn writes a packet with message to addr, using c.
// Deprecated; non-idiomatic.
func (m *Message) WriteToConn(c net.PacketConn, addr net.Addr) (n int, err error) {
return c.WriteTo(m.Raw, addr)
}
// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
// Decodes it and return error if any. If m.Raw is too small, will return
// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
//
// Can return *DecodeErr while decoding.
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
tBuf := m.Raw[:cap(m.Raw)]
var (
n int
err error
)
if n, err = r.Read(tBuf); err != nil {
return int64(n), err
}
m.Raw = tBuf[:n]
return int64(n), m.Decode()
}
func newAttrDecodeErr(children, message string) *DecodeErr {
return newDecodeErr("attribute", children, message)
}
// IsMessage returns true if b looks like STUN message.
// Useful for multiplexing. IsMessage does not guarantee
// that decoding will be successful.
func IsMessage(b []byte) bool {
// b should be at least messageHeaderSize bytes and
// contain correct magicCookie.
return len(b) >= messageHeaderSize &&
binary.BigEndian.Uint32(b[4:8]) == magicCookie
}
const (
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
// m.Raw to read header.
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header"
)
// Decode decodes m.Raw into m.
func (m *Message) Decode() error {
// decoding message header
buf := m.Raw
if len(buf) < messageHeaderSize {
return ErrUnexpectedHeaderEOF
}
var (
t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes
size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes
cookie = binary.BigEndian.Uint32(buf[4:8])
fullSize = messageHeaderSize + size
)
if cookie != magicCookie {
msg := fmt.Sprintf(
"%x is invalid magic cookie (should be %x)",
cookie, magicCookie,
)
return newDecodeErr("message", "cookie", msg)
}
if len(buf) < fullSize {
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected message size)",
len(buf), fullSize,
)
return newAttrDecodeErr("message", msg)
}
// saving header data
m.Type.ReadValue(t)
m.Length = uint32(size)
copy(m.TransactionID[:], buf[8:messageHeaderSize])
var (
offset = 0
b = buf[messageHeaderSize:fullSize]
)
for offset < size {
// checking that we have enough bytes to read header
if len(b) < attributeHeaderSize {
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected header size)",
len(b), attributeHeaderSize,
)
return newAttrDecodeErr("header", msg)
}
var (
a = RawAttribute{
Type: AttrType(bin.Uint16(b[0:2])), // first 2 bytes
Length: bin.Uint16(b[2:4]), // second 2 bytes
}
aL = int(a.Length) // attribute length
aBuffL = nearestLength(aL) // expected buffer length (with padding)
)
b = b[attributeHeaderSize:] // slicing again to simplify value read
offset += attributeHeaderSize
if len(b) < aBuffL { // checking size
msg := fmt.Sprintf(
"buffer length %d is less than %d (expected value size)",
len(b), aBuffL,
)
return newAttrDecodeErr("value", msg)
}
a.Value = b[:aL]
offset += aBuffL
b = b[aBuffL:]
m.Attributes = append(m.Attributes, a)
}
return nil
}
// ReadBytes decodes message and return error if any.
//
// Any error is unrecoverable, but message could be partially decoded.
// Deprecated: use m.Decode.
func (m *Message) ReadBytes(tBuf []byte) (int, error) {
m.Raw = append(m.Raw[:0], tBuf...)
return len(tBuf), m.Decode()
}
const (
attributeHeaderSize = 4
messageHeaderSize = 20
)
// MaxPacketSize is maximum size of UDP packet that is processable in
// this package for STUN message.
const MaxPacketSize = 2048
// MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
type MessageClass byte
// Possible values for message class in STUN Message Type.
const (
ClassRequest MessageClass = 0x00 // 0b00
ClassIndication MessageClass = 0x01 // 0b01
ClassSuccessResponse MessageClass = 0x02 // 0b10
ClassErrorResponse MessageClass = 0x03 // 0b11
)
func (c MessageClass) String() string {
switch c {
case ClassRequest:
return "request"
case ClassIndication:
return "indication"
case ClassSuccessResponse:
return "success response"
case ClassErrorResponse:
return "error response"
default:
panic("unknown message class")
}
}
// Method is uint16 representation of 12-bit STUN method.
type Method uint16
// Possible methods for STUN Message.
const (
MethodBinding Method = 0x001
MethodAllocate Method = 0x003
MethodRefresh Method = 0x004
MethodSend Method = 0x006
MethodData Method = 0x007
MethodCreatePermission Method = 0x008
MethodChannelBind Method = 0x009
)
func (m Method) String() string {
switch m {
case MethodBinding:
return "binding"
case MethodAllocate:
return "allocate"
case MethodRefresh:
return "refresh"
case MethodSend:
return "send"
case MethodData:
return "data"
case MethodCreatePermission:
return "create permission"
case MethodChannelBind:
return "channel bind"
default:
return fmt.Sprintf("0x%s", strconv.FormatUint(uint64(m), 16))
}
}
// MessageType is STUN Message Type Field.
type MessageType struct {
Class MessageClass
Method Method
}
const (
methodABits = 0xf // 0b0000000000001111
methodBBits = 0x70 // 0b0000000001110000
methodDBits = 0xf80 // 0b0000111110000000
methodBShift = 1
methodDShift = 2
firstBit = 0x1
secondBit = 0x2
c0Bit = firstBit
c1Bit = secondBit
classC0Shift = 4
classC1Shift = 7
)
// Value returns bit representation of messageType.
func (t MessageType) Value() uint16 {
// 0 1
// 2 3 4 5 6 7 8 9 0 1 2 3 4 5
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
// |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
// |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
// Figure 3: Format of STUN Message Type Field
// warning: Abandon all hope ye who enter here.
// splitting M into A(M0-M3), B(M4-M6), D(M7-M11)
m := uint16(t.Method)
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
// shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit)
m = a + (b << methodBShift) + (d << methodDShift)
// C0 is zero bit of C, C1 is fist bit.
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
// Ct = C0 << 4 + C1 << 8.
// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
// We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
// (see figure 3).
c := uint16(t.Class)
c0 := (c & c0Bit) << classC0Shift
c1 := (c & c1Bit) << classC1Shift
class := c0 + c1
return m + class
}
// ReadValue decodes uint16 into MessageType.
func (t *MessageType) ReadValue(v uint16) {
// decoding class
// we are taking first bit from v >> 4 and second from v >> 7.
c0 := (v >> classC0Shift) & c0Bit
c1 := (v >> classC1Shift) & c1Bit
class := c0 + c1
t.Class = MessageClass(class)
// decoding method
a := v & methodABits // A(M0-M3)
b := (v >> methodBShift) & methodBBits // B(M4-M6)
d := (v >> methodDShift) & methodDBits // D(M7-M11)
m := a + b + d
t.Method = Method(m)
}
func (t MessageType) String() string {
return fmt.Sprintf("%s %s", t.Method, t.Class)
}

View File

@@ -9,14 +9,13 @@ import (
"hash/crc64"
"io"
"io/ioutil"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"net"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
@@ -33,24 +32,11 @@ func (m *Message) reader() *bytes.Reader {
return bytes.NewReader(m.Raw)
}
func TestMessageCopy(t *testing.T) {
m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.AddRaw(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
mCopy := New()
m.CopyTo(mCopy)
if !mCopy.Equal(m) {
t.Error(mCopy, "!=", m)
}
}
func TestMessageBuffer(t *testing.T) {
m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.AddRaw(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
mDecoded := New()
if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
@@ -69,7 +55,7 @@ func BenchmarkMessage_Write(b *testing.B) {
transactionID := NewTransactionID()
m := New()
for i := 0; i < b.N; i++ {
m.AddRaw(AttrErrorCode, attributeValue)
m.Add(AttrErrorCode, attributeValue)
m.TransactionID = transactionID
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.WriteHeader()
@@ -133,6 +119,25 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
}
}
func TestMessage_WriteTo(t *testing.T) {
m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
buf := new(bytes.Buffer)
if _, err := m.WriteTo(buf); err != nil {
t.Fatal(err)
}
mDecoded := New()
if _, err := mDecoded.ReadFrom(buf); err != nil {
t.Error(err)
}
if !mDecoded.Equal(m) {
t.Error(mDecoded, "!", m)
}
}
func TestMessage_Cookie(t *testing.T) {
buf := make([]byte, 20)
mDecoded := New()
@@ -156,11 +161,11 @@ func TestMessage_BadLength(t *testing.T) {
Length: 4,
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
}
m.AddRaw(0x1, []byte{1, 2})
m.Add(0x1, []byte{1, 2})
m.WriteHeader()
m.Raw[20+3] = 10 // set attr length = 10
mDecoded := New()
if _, err := mDecoded.ReadBytes(m.Raw); err == nil {
if _, err := mDecoded.Write(m.Raw); err == nil {
t.Error("should error")
}
}
@@ -295,7 +300,7 @@ func BenchmarkMessage_ReadBytes(b *testing.B) {
b.SetBytes(int64(len(m.Raw)))
mRec := New()
for i := 0; i < b.N; i++ {
if _, err := mRec.ReadBytes(m.Raw); err != nil {
if _, err := mRec.Write(m.Raw); err != nil {
b.Fatal(err)
}
mRec.Reset()
@@ -439,9 +444,7 @@ func TestMessage_String(t *testing.T) {
func TestIsMessage(t *testing.T) {
m := New()
m.Add(&Software{
Raw: []byte("test"),
})
NewSoftware("software").AddTo(m)
m.WriteHeader()
var tt = [...]struct {
@@ -468,7 +471,7 @@ func BenchmarkIsMessage(b *testing.B) {
m := New()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.AddSoftware("cydev/stun test")
NewSoftware("cydev/stun test").AddTo(m)
m.WriteHeader()
b.SetBytes(int64(messageHeaderSize))
@@ -502,13 +505,13 @@ func loadData(tb testing.TB, name string) []byte {
func TestExampleChrome(t *testing.T) {
buf := loadData(t, "ex1_chrome.stun")
m := New()
_, err := m.ReadBytes(buf)
_, err := m.Write(buf)
if err != nil {
t.Errorf("Failed to parse ex1_chrome: %s", err)
}
}
func TestNearestLen(t *testing.T) {
func TestPadding(t *testing.T) {
tt := []struct {
in, out int
}{
@@ -525,7 +528,7 @@ func TestNearestLen(t *testing.T) {
{40, 40}, // 10
}
for i, c := range tt {
if got := nearestLength(c.in); got != c.out {
if got := nearestPaddedValueLength(c.in); got != c.out {
t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)",
i, c.in, got, c.out,
)
@@ -562,7 +565,7 @@ func TestMessageFromBrowsers(t *testing.T) {
if b != crc64.Checksum(data, crcTable) {
t.Error("crc64 check failed for ", line[1])
}
if _, err = m.ReadBytes(data); err != nil {
if _, err = m.Write(data); err != nil {
t.Error("failed to decode ", line[1], " as message: ", err)
}
m.Reset()
@@ -583,20 +586,29 @@ func BenchmarkNewTransactionID(b *testing.B) {
}
}
func BenchmarkMessage_NewTransactionID(b *testing.B) {
b.ReportAllocs()
m := new(Message)
for i := 0; i < b.N; i++ {
if err := m.NewTransactionID(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkMessageFull(b *testing.B) {
b.ReportAllocs()
m := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5),
IP: net.IPv4(213, 1, 223, 5),
}
fingerprint := new(Fingerprint)
for i := 0; i < b.N; i++ {
m.Add(addr)
m.Add(s)
addAttr(b, m, addr)
addAttr(b, m, s)
m.WriteAttributes()
m.WriteHeader()
fingerprint.AddTo(m)
Fingerprint.AddTo(m)
m.WriteHeader()
m.Reset()
}
@@ -607,19 +619,26 @@ func BenchmarkMessageFullHardcore(b *testing.B) {
m := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5),
IP: net.IPv4(213, 1, 223, 5),
}
t, v, _ := addr.Encode(nil, m)
for i := 0; i < b.N; i++ {
m.AddRaw(AttrSoftware, s.Raw)
m.AddRaw(t, v)
m.WriteAttributes()
if err := addr.AddTo(m); err != nil {
b.Fatal(err)
}
if err := s.AddTo(m); err != nil {
b.Fatal(err)
}
m.WriteHeader()
m.Reset()
}
}
func addAttr(t testing.TB, m *Message, a AttrEncoder) {
if err := m.Add(a); err != nil {
type attributeEncoder interface {
AddTo(m *Message) error
}
func addAttr(t testing.TB, m *Message, a attributeEncoder) {
if err := a.AddTo(m); err != nil {
t.Error(err)
}
}
@@ -629,14 +648,13 @@ func BenchmarkFingerprint_AddTo(b *testing.B) {
m := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5),
IP: net.IPv4(213, 1, 223, 5),
}
addAttr(b, m, addr)
addAttr(b, m, s)
fingerprint := new(Fingerprint)
b.SetBytes(int64(len(m.Raw)))
for i := 0; i < b.N; i++ {
fingerprint.AddTo(m)
Fingerprint.AddTo(m)
m.WriteLength()
m.Length -= attributeHeaderSize + fingerprintSize
m.Raw = m.Raw[:m.Length+messageHeaderSize]
@@ -648,12 +666,10 @@ func TestFingerprint_Check(t *testing.T) {
m := new(Message)
addAttr(t, m, NewSoftware("software"))
m.WriteHeader()
fInit := new(Fingerprint)
fInit.AddTo(m)
Fingerprint.AddTo(m)
m.WriteHeader()
f := new(Fingerprint)
if err := f.Check(m); err != nil {
t.Errorf("decoded: %d, encoded: %d, err: %s", f.Value, fInit.Value, err)
if err := Fingerprint.Check(m); err != nil {
t.Error(err)
}
}
@@ -662,17 +678,16 @@ func BenchmarkFingerprint_Check(b *testing.B) {
m := new(Message)
s := NewSoftware("software")
addr := &XORMappedAddress{
ip: net.IPv4(213, 1, 223, 5),
IP: net.IPv4(213, 1, 223, 5),
}
addAttr(b, m, addr)
addAttr(b, m, s)
fingerprint := new(Fingerprint)
m.WriteHeader()
fingerprint.AddTo(m)
Fingerprint.AddTo(m)
m.WriteHeader()
b.SetBytes(int64(len(m.Raw)))
for i := 0; i < b.N; i++ {
if err := fingerprint.Check(m); err != nil {
if err := Fingerprint.Check(m); err != nil {
b.Fatal(err)
}
}

View File

@@ -5,9 +5,11 @@ import (
"fmt"
"log"
"net"
"runtime"
"net/http"
"strings"
_ "net/http/pprof"
"github.com/ernado/stun"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@@ -16,6 +18,8 @@ import (
var (
network = flag.String("net", "udp", "network to listen")
address = flag.String("addr", "0.0.0.0:3478", "address to listen")
workers = flag.Int("workers", 1, "workers to start")
profile = flag.Bool("profile", false, "profile")
)
// Server is RFC 5389 basic server implementation.
@@ -80,7 +84,7 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
if !stun.IsMessage(b) {
return errNotSTUNMessage
}
if _, err := req.ReadBytes(b); err != nil {
if _, err := req.Write(b); err != nil {
return errors.Wrap(err, "failed to read message")
}
res.TransactionID = req.TransactionID
@@ -99,8 +103,12 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
default:
panic(fmt.Sprintf("unknown addr: %v", addr))
}
res.AddXORMappedAddress(ip, port)
res.AddRaw(stun.AttrSoftware, software.Raw)
xma := &stun.XORMappedAddress{
IP: ip,
Port: port,
}
xma.AddTo(res)
software.AddTo(res)
res.WriteHeader()
return nil
}
@@ -116,8 +124,8 @@ func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error {
return nil
}
// s.logger().Printf("read %d bytes from %s", n, addr)
if _, err = req.ReadBytes(buf[:n]); err != nil {
s.logger().Printf("ReadBytes: %v", err)
if _, err = req.Write(buf[:n]); err != nil {
s.logger().Printf("Write: %v", err)
return err
}
if err = basicProcess(addr, buf[:n], req, res); err != nil {
@@ -156,8 +164,10 @@ func ListenUDPAndServe(serverNet, laddr string) error {
if err != nil {
return err
}
s := &Server{}
return s.Serve(c)
for i := 1; i < *workers; i++ {
go new(Server).Serve(c)
}
return new(Server).Serve(c)
}
func normalize(address string) string {
@@ -172,7 +182,11 @@ func normalize(address string) string {
func main() {
flag.Parse()
runtime.GOMAXPROCS(1)
if *profile {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
}
switch *network {
case "udp":
normalized := normalize(*address)

View File

@@ -17,7 +17,7 @@ func BenchmarkBasicProcess(b *testing.B) {
b.Fatal(err)
}
m.TransactionID = stun.NewTransactionID()
m.Add(stun.NewSoftware("some software"))
stun.NewSoftware("some software").AddTo(m)
m.WriteHeader()
b.SetBytes(int64(len(m.Raw)))
for i := 0; i < b.N; i++ {

2
xor.go
View File

@@ -34,7 +34,7 @@ func fastXORBytes(dst, a, b []byte) int {
}
}
for i := (n - n%wordSize); i < n; i++ {
for i := n - n%wordSize; i < n; i++ {
dst[i] = a[i] ^ b[i]
}