mirror of
https://github.com/pion/stun.git
synced 2025-10-04 15:32:46 +08:00
all: refactor attributes
This commit is contained in:
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
Copyright (c) 2016 Aleksandr Razumov, Cydev. All Rigths Reserved.
|
||||
Copyright (c) 2016-2017 Aleksandr Razumov, Cydev. All Rigths Reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
|
@@ -2,39 +2,88 @@ package stun
|
||||
|
||||
// ErrorCodeAttribute represents ERROR-CODE attribute.
|
||||
type ErrorCodeAttribute struct {
|
||||
Code int
|
||||
Code ErrorCode
|
||||
Reason []byte
|
||||
}
|
||||
|
||||
// constants for ERROR-CODE encoding.
|
||||
const (
|
||||
errorCodeReasonStart = 4
|
||||
errorCodeClassByte = 2
|
||||
errorCodeNumberByte = 3
|
||||
errorCodeReasonMaxB = 763
|
||||
errorCodeModulo = 100
|
||||
)
|
||||
|
||||
// AddTo adds ERROR-CODE to m.
|
||||
func (c *ErrorCodeAttribute) AddTo(m *Message) error {
|
||||
value := make([]byte,
|
||||
errorCodeReasonStart, errorCodeReasonMaxB,
|
||||
)
|
||||
number := byte(c.Code % errorCodeModulo) // error code modulo 100
|
||||
class := byte(c.Code / errorCodeModulo) // hundred digit
|
||||
value[errorCodeClassByte] = class
|
||||
value[errorCodeNumberByte] = number
|
||||
value = append(value, c.Reason...)
|
||||
m.Add(AttrErrorCode, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFrom decodes ERROR-CODE from m.
|
||||
func (c *ErrorCodeAttribute) GetFrom(m *Message) error {
|
||||
v, err := m.Get(AttrErrorCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
class = uint16(v[errorCodeClassByte])
|
||||
number = uint16(v[errorCodeNumberByte])
|
||||
code = int(class*errorCodeModulo + number)
|
||||
reason = v[errorCodeReasonStart:]
|
||||
)
|
||||
c.Code = ErrorCode(code)
|
||||
c.Reason = reason
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrorCode is code for ERROR-CODE attribute.
|
||||
type ErrorCode int
|
||||
|
||||
// ErrNoDefaultReason means that default reason for provided error code
|
||||
// is not defined in RFC.
|
||||
const ErrNoDefaultReason Error = "No default reason for ErrorCode"
|
||||
|
||||
// AddTo adds ERROR-CODE with default reason to m. If there
|
||||
// is no default reason, returns ErrNoDefaultReason.
|
||||
func (c ErrorCode) AddTo(m *Message) error {
|
||||
reason := errorReasons[c]
|
||||
if reason == nil {
|
||||
return ErrNoDefaultReason
|
||||
}
|
||||
a := &ErrorCodeAttribute{
|
||||
Code: c,
|
||||
Reason: reason,
|
||||
}
|
||||
return a.AddTo(m)
|
||||
}
|
||||
|
||||
// Possible error codes.
|
||||
const (
|
||||
CodeTryAlternate = 300
|
||||
CodeBadRequest = 400
|
||||
CodeUnauthorised = 401
|
||||
CodeUnknownAttribute = 420
|
||||
CodeStaleNonce = 428
|
||||
CodeRoleConflict = 478
|
||||
CodeServerError = 500
|
||||
CodeTryAlternate ErrorCode = 300
|
||||
CodeBadRequest ErrorCode = 400
|
||||
CodeUnauthorised ErrorCode = 401
|
||||
CodeUnknownAttribute ErrorCode = 420
|
||||
CodeStaleNonce ErrorCode = 428
|
||||
CodeRoleConflict ErrorCode = 478
|
||||
CodeServerError ErrorCode = 500
|
||||
)
|
||||
|
||||
var errorReasons = map[int]string{
|
||||
CodeTryAlternate: "Try Alternate",
|
||||
CodeBadRequest: "Bad Request",
|
||||
CodeUnauthorised: "Unauthorised",
|
||||
CodeUnknownAttribute: "Unknown Attribute",
|
||||
CodeStaleNonce: "Stale Nonce",
|
||||
CodeServerError: "Server Error",
|
||||
CodeRoleConflict: "Role Conflict",
|
||||
}
|
||||
|
||||
// Reason returns recommended reason string.
|
||||
func (c ErrorCode) Reason() string {
|
||||
reason, ok := errorReasons[int(c)]
|
||||
if !ok {
|
||||
return "Unknown Error"
|
||||
}
|
||||
return reason
|
||||
var errorReasons = map[ErrorCode][]byte{
|
||||
CodeTryAlternate: []byte("Try Alternate"),
|
||||
CodeBadRequest: []byte("Bad Request"),
|
||||
CodeUnauthorised: []byte("Unauthorised"),
|
||||
CodeUnknownAttribute: []byte("Unknown Attribute"),
|
||||
CodeStaleNonce: []byte("Stale Nonce"),
|
||||
CodeServerError: []byte("Server Error"),
|
||||
CodeRoleConflict: []byte("Role Conflict"),
|
||||
}
|
||||
|
@@ -1,26 +0,0 @@
|
||||
package stun
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestErrorCode_Reason(t *testing.T) {
|
||||
codes := [...]ErrorCode{
|
||||
CodeTryAlternate,
|
||||
CodeBadRequest,
|
||||
CodeUnauthorised,
|
||||
CodeUnknownAttribute,
|
||||
CodeStaleNonce,
|
||||
CodeRoleConflict,
|
||||
CodeServerError,
|
||||
}
|
||||
for _, code := range codes {
|
||||
if code.Reason() == "Unknown Error" {
|
||||
t.Error(code, "should not be unknown")
|
||||
}
|
||||
if len(code.Reason()) == 0 {
|
||||
t.Error(code, "should not be blank")
|
||||
}
|
||||
}
|
||||
if ErrorCode(999).Reason() != "Unknown Error" {
|
||||
t.Error("999 error should be Unknown")
|
||||
}
|
||||
}
|
68
attribute_fingerprint.go
Normal file
68
attribute_fingerprint.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
)
|
||||
|
||||
// FingerprintAttr represent FINGERPRINT attribute.
|
||||
type FingerprintAttr struct{}
|
||||
|
||||
// CRCMismatch represents CRC check error.
|
||||
type CRCMismatch struct {
|
||||
Expected uint32
|
||||
Actual uint32
|
||||
}
|
||||
|
||||
func (m CRCMismatch) Error() string {
|
||||
return fmt.Sprintf("CRC mismatch: %x (expected) != %x (actual)",
|
||||
m.Expected,
|
||||
m.Actual,
|
||||
)
|
||||
}
|
||||
|
||||
// Fingerprint is shorthand for FingerprintAttr.
|
||||
var Fingerprint = &FingerprintAttr{}
|
||||
|
||||
const (
|
||||
fingerprintXORValue uint32 = 0x5354554e
|
||||
fingerprintSize = 4 // 32 bit
|
||||
)
|
||||
|
||||
// FingerprintValue returns CRC32 of m XOR-ed by 0x5354554e.
|
||||
func FingerprintValue(b []byte) uint32 {
|
||||
return crc32.ChecksumIEEE(b) ^ fingerprintXORValue // XOR
|
||||
}
|
||||
|
||||
// AddTo adds fingerprint to message.
|
||||
func (FingerprintAttr) AddTo(m *Message) error {
|
||||
l := m.Length
|
||||
// length in header should include size of fingerprint attribute
|
||||
m.Length += fingerprintSize + attributeHeaderSize // increasing length
|
||||
m.WriteLength() // writing Length to Raw
|
||||
b := make([]byte, fingerprintSize)
|
||||
v := FingerprintValue(m.Raw)
|
||||
bin.PutUint32(b, v)
|
||||
m.Length = l
|
||||
m.Add(AttrFingerprint, b)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check reads fingerprint value from m and checks it, returning error if any.
|
||||
// Can return *DecodeErr, ErrAttributeNotFound and *CRCMismatch.
|
||||
func (FingerprintAttr) Check(m *Message) error {
|
||||
v, err := m.Get(AttrFingerprint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(v) != fingerprintSize {
|
||||
return newDecodeErr("message", "fingerprint", "bad length")
|
||||
}
|
||||
val := bin.Uint32(v)
|
||||
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
|
||||
expected := FingerprintValue(m.Raw[:attrStart])
|
||||
if expected != val {
|
||||
return &CRCMismatch{Expected: expected, Actual: val}
|
||||
}
|
||||
return nil
|
||||
}
|
31
attribute_software.go
Normal file
31
attribute_software.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package stun
|
||||
|
||||
// Software is SOFTWARE attribute.
|
||||
type Software struct {
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
func (s *Software) String() string {
|
||||
return string(s.Raw)
|
||||
}
|
||||
|
||||
// NewSoftware returns *Software from string.
|
||||
func NewSoftware(software string) *Software {
|
||||
return &Software{Raw: []byte(software)}
|
||||
}
|
||||
|
||||
// AddTo adds Software attribute to m.
|
||||
func (s *Software) AddTo(m *Message) error {
|
||||
m.Add(AttrSoftware, m.Raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFrom decodes Software from m.
|
||||
func (s *Software) GetFrom(m *Message) error {
|
||||
v, err := m.Get(AttrSoftware)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Raw = v
|
||||
return nil
|
||||
}
|
118
attribute_xoraddr.go
Normal file
118
attribute_xoraddr.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
familyIPv4 byte = 0x01
|
||||
familyIPv6 byte = 0x02
|
||||
)
|
||||
|
||||
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
|
||||
type XORMappedAddress struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
}
|
||||
|
||||
func (a XORMappedAddress) String() string {
|
||||
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
||||
}
|
||||
|
||||
// Is p all zeros?
|
||||
func isZeros(p net.IP) bool {
|
||||
for i := 0; i < len(p); i++ {
|
||||
if p[i] != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}.
|
||||
const ErrBadIPLength Error = "invalid length if IP value"
|
||||
|
||||
// AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength
|
||||
// if len(a.IP) is invalid.
|
||||
func (a *XORMappedAddress) AddTo(m *Message) error {
|
||||
var (
|
||||
family = familyIPv4
|
||||
ip = a.IP
|
||||
)
|
||||
if len(a.IP) == net.IPv6len {
|
||||
// Optimized for performance. See net.IP.To4 method.
|
||||
if isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
|
||||
ip = ip[12:16] // like in ip.To4()
|
||||
} else {
|
||||
family = familyIPv6
|
||||
}
|
||||
} else if len(ip) != net.IPv4len {
|
||||
return ErrBadIPLength
|
||||
}
|
||||
value := make([]byte, 32+128)
|
||||
value[0] = 0 // first 8 bits are zeroes
|
||||
xorValue := make([]byte, net.IPv6len)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
bin.PutUint32(xorValue[0:4], magicCookie)
|
||||
bin.PutUint16(value[0:2], uint16(family))
|
||||
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16))
|
||||
xorBytes(value[4:4+len(ip)], ip, xorValue)
|
||||
m.Add(AttrXORMappedAddress, value[:4+len(ip)])
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFrom decodes XOR-MAPPED-ADDRESS attribute in message and returns
|
||||
// error if any. While decoding, a.IP is reused if possible and can be
|
||||
// rendered to invalid state (e.g. if a.IP was set to IPv6 and then
|
||||
// IPv4 value were decoded into it), be careful.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// expectedIP := net.ParseIP("213.141.156.236")
|
||||
// expectedIP.String() // 213.141.156.236, 16 bytes, first 12 of them are zeroes
|
||||
// expectedPort := 21254
|
||||
// addr := &XORMappedAddress{
|
||||
// IP: expectedIP,
|
||||
// Port: expectedPort,
|
||||
// }
|
||||
// // addr were added to message that is decoded as newMessage
|
||||
// // ...
|
||||
//
|
||||
// addr.GetFrom(newMessage)
|
||||
// addr.IP.String() // 213.141.156.236, net.IPv4Len
|
||||
// expectedIP.String() // d58d:9cec::ffff:d58d:9cec, 16 bytes, first 4 are IPv4
|
||||
// // now we have len(expectedIP) = 16 and len(addr.IP) = 4.
|
||||
func (a *XORMappedAddress) GetFrom(m *Message) error {
|
||||
v, err := m.Get(AttrXORMappedAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
family := byte(bin.Uint16(v[0:2]))
|
||||
if family != familyIPv6 && family != familyIPv4 {
|
||||
return newDecodeErr("xor-mapped address", "family",
|
||||
fmt.Sprintf("bad value %d", family),
|
||||
)
|
||||
}
|
||||
ipLen := net.IPv4len
|
||||
if family == familyIPv6 {
|
||||
ipLen = net.IPv6len
|
||||
}
|
||||
// Ensuring len(a.IP) == ipLen and reusing a.IP.
|
||||
if len(a.IP) < ipLen {
|
||||
a.IP = a.IP[:cap(a.IP)]
|
||||
for len(a.IP) < ipLen {
|
||||
a.IP = append(a.IP, 0)
|
||||
}
|
||||
}
|
||||
a.IP = a.IP[:ipLen]
|
||||
for i := range a.IP {
|
||||
a.IP[i] = 0
|
||||
}
|
||||
a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16)
|
||||
xorValue := make([]byte, 4+transactionIDSize)
|
||||
bin.PutUint32(xorValue[0:4], magicCookie)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
xorBytes(a.IP, v[4:], xorValue)
|
||||
return nil
|
||||
}
|
347
attributes.go
347
attributes.go
@@ -1,23 +1,10 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// AttrWriter wraps AddRaw method.
|
||||
type AttrWriter interface {
|
||||
AddRaw(t AttrType, v []byte)
|
||||
}
|
||||
|
||||
// AttrEncoder wraps Encode method.
|
||||
type AttrEncoder interface {
|
||||
Encode(b []byte, m *Message) (AttrType, []byte, error)
|
||||
}
|
||||
|
||||
// Attributes is list of message attributes.
|
||||
type Attributes []RawAttribute
|
||||
|
||||
@@ -77,7 +64,7 @@ const (
|
||||
AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN
|
||||
)
|
||||
|
||||
// Attributes from An Origin RawAttribute for the STUN Protocol.
|
||||
// Attributes from An Origin Attribute for the STUN Protocol.
|
||||
const (
|
||||
AttrOrigin AttrType = 0x802F
|
||||
)
|
||||
@@ -118,7 +105,7 @@ var attrNames = map[AttrType]string{
|
||||
func (t AttrType) String() string {
|
||||
s, ok := attrNames[t]
|
||||
if !ok {
|
||||
// just return hex representation of unknown attribute type
|
||||
// Just return hex representation of unknown attribute type.
|
||||
return "0x" + strconv.FormatUint(uint64(t), 16)
|
||||
}
|
||||
return s
|
||||
@@ -131,21 +118,17 @@ func (t AttrType) String() string {
|
||||
// don't understand, but cannot successfully process a message if it
|
||||
// contains comprehension-required attributes that are not
|
||||
// understood.
|
||||
//
|
||||
// TODO(ar): Decide to use pointer or non-pointer RawAttribute.
|
||||
type RawAttribute struct {
|
||||
Type AttrType
|
||||
Length uint16 // ignored while encoding
|
||||
Value []byte
|
||||
}
|
||||
|
||||
// Encode implements AttrEncoder.
|
||||
func (a *RawAttribute) Encode(m *Message) ([]byte, error) {
|
||||
return m.Raw, nil
|
||||
}
|
||||
|
||||
// Decode implements AttrDecoder.
|
||||
func (a *RawAttribute) Decode(v []byte, m *Message) error {
|
||||
a.Value = v
|
||||
a.Length = uint16(len(v))
|
||||
// AddTo adds RawAttribute to m.
|
||||
func (a *RawAttribute) AddTo(m *Message) error {
|
||||
m.Add(a.Type, m.Raw)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -172,319 +155,17 @@ func (a RawAttribute) String() string {
|
||||
return fmt.Sprintf("%s: %x", a.Type, a.Value)
|
||||
}
|
||||
|
||||
// getAttrValue returns byte slice that represents attribute value,
|
||||
// if there is no value attribute with shuck type,
|
||||
// ErrAttributeNotFound means that attribute with provided attribute
|
||||
// type does not exist in message.
|
||||
const ErrAttributeNotFound Error = "Attribute not found"
|
||||
|
||||
// Get returns byte slice that represents attribute value,
|
||||
// if there is no attribute with such type,
|
||||
// ErrAttributeNotFound is returned.
|
||||
func (m *Message) getAttrValue(t AttrType) ([]byte, error) {
|
||||
func (m *Message) Get(t AttrType) ([]byte, error) {
|
||||
v, ok := m.Attributes.Get(t)
|
||||
if !ok {
|
||||
return nil, ErrAttributeNotFound
|
||||
}
|
||||
return v.Value, nil
|
||||
}
|
||||
|
||||
// AddSoftware adds SOFTWARE attribute with value from string.
|
||||
// Deprecated: use AddRaw.
|
||||
func (m *Message) AddSoftware(software string) {
|
||||
m.AddRaw(AttrSoftware, []byte(software))
|
||||
}
|
||||
|
||||
// Set sets the value of attribute if it presents.
|
||||
func (m *Message) Set(a AttrEncoder) error {
|
||||
var (
|
||||
v []byte
|
||||
err error
|
||||
t AttrType
|
||||
)
|
||||
t, v, err = a.Encode(v, m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf, err := m.getAttrValue(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(v) != len(buf) {
|
||||
return ErrBadSetLength
|
||||
}
|
||||
copy(buf, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSoftwareBytes returns SOFTWARE attribute value in byte slice.
|
||||
// If not found, returns nil.
|
||||
func (m *Message) GetSoftwareBytes() []byte {
|
||||
v, ok := m.Attributes.Get(AttrSoftware)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return v.Value
|
||||
}
|
||||
|
||||
// GetSoftware returns SOFTWARE attribute value in string.
|
||||
// If not found, returns blank string.
|
||||
// Deprecated.
|
||||
func (m *Message) GetSoftware() string { return string(m.GetSoftwareBytes()) }
|
||||
|
||||
// Address family values.
|
||||
const (
|
||||
FamilyIPv4 byte = 0x01
|
||||
FamilyIPv6 byte = 0x02
|
||||
)
|
||||
|
||||
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
|
||||
type XORMappedAddress struct {
|
||||
ip net.IP
|
||||
port int
|
||||
}
|
||||
|
||||
// Encode implements AttrEncoder.
|
||||
func (a *XORMappedAddress) Encode(buf []byte, m *Message) (AttrType, []byte, error) {
|
||||
// X-Port is computed by taking the mapped port in host byte order,
|
||||
// XOR’ing it with the most significant 16 bits of the magic cookie, and
|
||||
// then the converting the result to network byte order.
|
||||
family := FamilyIPv6
|
||||
ip := a.ip
|
||||
port := a.port
|
||||
if ipV4 := ip.To4(); ipV4 != nil {
|
||||
ip = ipV4
|
||||
family = FamilyIPv4
|
||||
}
|
||||
value := make([]byte, 32+128)
|
||||
value[0] = 0 // first 8 bits are zeroes
|
||||
xorValue := make([]byte, net.IPv6len)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
|
||||
port ^= magicCookie >> 16
|
||||
binary.BigEndian.PutUint16(value[0:2], uint16(family))
|
||||
binary.BigEndian.PutUint16(value[2:4], uint16(port))
|
||||
xorBytes(value[4:4+len(ip)], ip, xorValue)
|
||||
buf = append(buf, value...)
|
||||
return AttrXORMappedAddress, buf, nil
|
||||
}
|
||||
|
||||
// Decode implements AttrDecoder.
|
||||
// TODO(ar): fix signature.
|
||||
func (a *XORMappedAddress) Decode(v []byte, m *Message) error {
|
||||
// X-Port is computed by taking the mapped port in host byte order,
|
||||
// XOR’ing it with the most significant 16 bits of the magic cookie, and
|
||||
// then the converting the result to network byte order.
|
||||
v, err := m.getAttrValue(AttrXORMappedAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
family := byte(binary.BigEndian.Uint16(v[0:2]))
|
||||
if family != FamilyIPv6 && family != FamilyIPv4 {
|
||||
return newDecodeErr("xor-mapped address", "family",
|
||||
fmt.Sprintf("bad value %d", family),
|
||||
)
|
||||
}
|
||||
ipLen := net.IPv4len
|
||||
if family == FamilyIPv6 {
|
||||
ipLen = net.IPv6len
|
||||
}
|
||||
ip := net.IP(m.allocBuffer(ipLen))
|
||||
a.port = int(binary.BigEndian.Uint16(v[2:4])) ^ (magicCookie >> 16)
|
||||
xorValue := make([]byte, 128)
|
||||
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
xorBytes(ip, v[4:], xorValue)
|
||||
a.ip = ip
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddXORMappedAddress adds XOR MAPPED ADDRESS attribute to message.
|
||||
// Deprecated: use AddRaw.
|
||||
func (m *Message) AddXORMappedAddress(ip net.IP, port int) {
|
||||
// X-Port is computed by taking the mapped port in host byte order,
|
||||
// XOR’ing it with the most significant 16 bits of the magic cookie, and
|
||||
// then the converting the result to network byte order.
|
||||
family := FamilyIPv6
|
||||
if ipV4 := ip.To4(); ipV4 != nil {
|
||||
ip = ipV4
|
||||
family = FamilyIPv4
|
||||
}
|
||||
value := make([]byte, 32+128)
|
||||
value[0] = 0 // first 8 bits are zeroes
|
||||
xorValue := make([]byte, net.IPv6len)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
|
||||
port ^= magicCookie >> 16
|
||||
binary.BigEndian.PutUint16(value[0:2], uint16(family))
|
||||
binary.BigEndian.PutUint16(value[2:4], uint16(port))
|
||||
xorBytes(value[4:4+len(ip)], ip, xorValue)
|
||||
m.AddRaw(AttrXORMappedAddress, value[:4+len(ip)])
|
||||
}
|
||||
|
||||
func (m *Message) allocBuffer(size int) []byte {
|
||||
capacity := len(m.Raw) + size
|
||||
m.grow(capacity)
|
||||
m.Raw = m.Raw[:capacity]
|
||||
return m.Raw[len(m.Raw)-size:]
|
||||
}
|
||||
|
||||
// GetXORMappedAddress returns ip, port from attribute and error if any.
|
||||
// Value for ip is valid until Message is released or underlying buffer is
|
||||
// corrupted. Returns *DecodeError or ErrAttributeNotFound.
|
||||
// Deprecated: use GetRaw.
|
||||
func (m *Message) GetXORMappedAddress() (net.IP, int, error) {
|
||||
// X-Port is computed by taking the mapped port in host byte order,
|
||||
// XOR’ing it with the most significant 16 bits of the magic cookie, and
|
||||
// then the converting the result to network byte order.
|
||||
v, err := m.getAttrValue(AttrXORMappedAddress)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
family := byte(binary.BigEndian.Uint16(v[0:2]))
|
||||
if family != FamilyIPv6 && family != FamilyIPv4 {
|
||||
return nil, 0, newDecodeErr("xor-mapped address", "family",
|
||||
fmt.Sprintf("bad value %d", family),
|
||||
)
|
||||
}
|
||||
ipLen := net.IPv4len
|
||||
if family == FamilyIPv6 {
|
||||
ipLen = net.IPv6len
|
||||
}
|
||||
ip := net.IP(m.allocBuffer(ipLen))
|
||||
port := int(binary.BigEndian.Uint16(v[2:4])) ^ (magicCookie >> 16)
|
||||
xorValue := make([]byte, 128)
|
||||
binary.BigEndian.PutUint32(xorValue[0:4], magicCookie)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
xorBytes(ip, v[4:], xorValue)
|
||||
return ip, port, nil
|
||||
}
|
||||
|
||||
// constants for ERROR-CODE encoding.
|
||||
const (
|
||||
errorCodeReasonStart = 4
|
||||
errorCodeClassByte = 2
|
||||
errorCodeNumberByte = 3
|
||||
errorCodeReasonMaxB = 763
|
||||
errorCodeModulo = 100
|
||||
)
|
||||
|
||||
// AddErrorCode adds ERROR-CODE attribute to message.
|
||||
//
|
||||
// The reason phrase MUST be a UTF-8 [RFC 3629] encoded
|
||||
// sequence of less than 128 characters (which can be as long as 763
|
||||
// bytes).
|
||||
// Deprecated: use AddRaw.
|
||||
func (m *Message) AddErrorCode(code int, reason string) {
|
||||
value := make([]byte,
|
||||
errorCodeReasonStart, errorCodeReasonMaxB+errorCodeReasonStart,
|
||||
)
|
||||
number := byte(code % errorCodeModulo) // error code modulo 100
|
||||
class := byte(code / errorCodeModulo) // hundred digit
|
||||
value[errorCodeClassByte] = class
|
||||
value[errorCodeNumberByte] = number
|
||||
value = append(value, reason...)
|
||||
m.AddRaw(AttrErrorCode, value)
|
||||
}
|
||||
|
||||
// AddErrorCodeDefault is wrapper for AddErrorCode that uses recommended
|
||||
// reason string from RFC. If error code is unknown, reason will be "Unknown
|
||||
// Error".
|
||||
// Deprecated: use AddRaw.
|
||||
func (m *Message) AddErrorCodeDefault(code int) {
|
||||
m.AddErrorCode(code, ErrorCode(code).Reason())
|
||||
}
|
||||
|
||||
// GetErrorCode returns ERROR-CODE code, reason and decode error if any.
|
||||
// Deprecated: use GetRaw.
|
||||
func (m *Message) GetErrorCode() (int, []byte, error) {
|
||||
v, err := m.getAttrValue(AttrErrorCode)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
var (
|
||||
class = uint16(v[errorCodeClassByte])
|
||||
number = uint16(v[errorCodeNumberByte])
|
||||
code = int(class*errorCodeModulo + number)
|
||||
reason = v[errorCodeReasonStart:]
|
||||
)
|
||||
return code, reason, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// ErrAttributeNotFound means that there is no such attribute.
|
||||
ErrAttributeNotFound Error = "Attribute not found"
|
||||
|
||||
// ErrBadSetLength means that previous attribute value length differs from
|
||||
// new value.
|
||||
ErrBadSetLength Error = "Previous attribute length is different"
|
||||
)
|
||||
|
||||
// Software is SOFTWARE attribute.
|
||||
type Software struct {
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
func (s Software) String() string {
|
||||
return string(s.Raw)
|
||||
}
|
||||
|
||||
// NewSoftware returns *Software from string.
|
||||
func NewSoftware(software string) *Software {
|
||||
return &Software{Raw: []byte(software)}
|
||||
}
|
||||
|
||||
// Encode implements AttrEncoder.
|
||||
func (s *Software) Encode(b []byte, m *Message) (AttrType, []byte, error) {
|
||||
return AttrSoftware, append(b, s.Raw...), nil
|
||||
}
|
||||
|
||||
// Decode implements AttrDecoder.
|
||||
func (s *Software) Decode(v []byte, m *Message) error {
|
||||
s.Raw = v
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
fingerprintXORValue uint32 = 0x5354554e
|
||||
)
|
||||
|
||||
// Fingerprint represents FINGERPRINT attribute.
|
||||
type Fingerprint struct {
|
||||
Value uint32 // CRC-32 of message XOR-ed with 0x5354554e
|
||||
}
|
||||
|
||||
const (
|
||||
fingerprintSize = 4 // 32 bit
|
||||
)
|
||||
|
||||
// AddTo adds fingerprint to message.
|
||||
func (f *Fingerprint) AddTo(m *Message) error {
|
||||
l := m.Length
|
||||
// length in header should include size of fingerprint attribute
|
||||
m.Length += fingerprintSize + attributeHeaderSize // increasing length
|
||||
m.WriteLength() // writing Length to Raw
|
||||
b := make([]byte, fingerprintSize)
|
||||
f.Value = crc32.ChecksumIEEE(m.Raw) ^ fingerprintXORValue // XOR
|
||||
bin.PutUint32(b, f.Value)
|
||||
m.Length = l
|
||||
m.AddRaw(AttrFingerprint, b)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check reads fingerprint value from m and checks it, returning error if any.
|
||||
// Can return *DecodeErr, ErrAttributeNotFound, ErrCRCMissmatch.
|
||||
func (f *Fingerprint) Check(m *Message) error {
|
||||
v, err := m.getAttrValue(AttrFingerprint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(v) != fingerprintSize {
|
||||
return newDecodeErr("message", "fingerprint", "bad length")
|
||||
}
|
||||
f.Value = bin.Uint32(v)
|
||||
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)
|
||||
expected := crc32.ChecksumIEEE(m.Raw[:attrStart]) ^ fingerprintXORValue
|
||||
if expected != f.Value {
|
||||
return ErrCRCMissmatch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrCRCMissmatch means that calculated fingerprint attribute differs from
|
||||
// expected one.
|
||||
const ErrCRCMissmatch Error = "CRC32 missmatch: bad fingerprint value"
|
||||
|
@@ -10,21 +10,24 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessage_AddSoftware(t *testing.T) {
|
||||
func TestSoftware_GetFrom(t *testing.T) {
|
||||
m := New()
|
||||
v := "Client v0.0.1"
|
||||
m.AddRaw(AttrSoftware, []byte(v))
|
||||
m.Add(AttrSoftware, []byte(v))
|
||||
m.WriteHeader()
|
||||
|
||||
m2 := &Message{
|
||||
Raw: make([]byte, 0, 256),
|
||||
}
|
||||
software := new(Software)
|
||||
if _, err := m2.ReadFrom(m.reader()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
vRead := m.GetSoftware()
|
||||
if vRead != v {
|
||||
t.Errorf("Expected %s, got %s.", v, vRead)
|
||||
if err := software.GetFrom(m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if software.String() != v {
|
||||
t.Errorf("Expected %q, got %q.", v, software)
|
||||
}
|
||||
|
||||
sAttr, ok := m.Attributes.Get(AttrSoftware)
|
||||
@@ -37,29 +40,18 @@ func TestMessage_AddSoftware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_GetSoftware(t *testing.T) {
|
||||
m := New()
|
||||
v := m.GetSoftware()
|
||||
if v != "" {
|
||||
t.Errorf("%s should be blank.", v)
|
||||
}
|
||||
vByte := m.GetSoftwareBytes()
|
||||
if vByte != nil {
|
||||
t.Errorf("%s should be nil.", vByte)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessage_AddXORMappedAddress(b *testing.B) {
|
||||
func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
|
||||
m := New()
|
||||
b.ReportAllocs()
|
||||
ip := net.ParseIP("192.168.1.32")
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.AddXORMappedAddress(ip, 3654)
|
||||
addr := &XORMappedAddress{IP: ip, Port: 3654}
|
||||
addr.AddTo(m)
|
||||
m.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessage_GetXORMappedAddress(b *testing.B) {
|
||||
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
|
||||
m := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
@@ -70,10 +62,13 @@ func BenchmarkMessage_GetXORMappedAddress(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
m.Add(AttrXORMappedAddress, addrValue)
|
||||
addr := new(XORMappedAddress)
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.AddRaw(AttrXORMappedAddress, addrValue)
|
||||
m.GetXORMappedAddress()
|
||||
m.Reset()
|
||||
if err := addr.GetFrom(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,16 +83,16 @@ func TestMessage_GetXORMappedAddress(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
m.AddRaw(AttrXORMappedAddress, addrValue)
|
||||
ip, port, err := m.GetXORMappedAddress()
|
||||
if err != nil {
|
||||
m.Add(AttrXORMappedAddress, addrValue)
|
||||
addr := new(XORMappedAddress)
|
||||
if err = addr.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !ip.Equal(net.ParseIP("213.141.156.236")) {
|
||||
t.Error("bad ip", ip, "!=", "213.141.156.236")
|
||||
if !addr.IP.Equal(net.ParseIP("213.141.156.236")) {
|
||||
t.Error("bad IP", addr.IP, "!=", "213.141.156.236")
|
||||
}
|
||||
if port != 48583 {
|
||||
t.Error("bad port", port, "!=", 48583)
|
||||
if addr.Port != 48583 {
|
||||
t.Error("bad Port", addr.Port, "!=", 48583)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,13 +105,15 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("213.141.156.236")
|
||||
expectedPort := 21254
|
||||
addr := new(XORMappedAddress)
|
||||
|
||||
_, _, err = m.GetXORMappedAddress()
|
||||
if err == nil {
|
||||
if err = addr.GetFrom(m); err == nil {
|
||||
t.Fatal(err, "should be nil")
|
||||
}
|
||||
|
||||
m.AddXORMappedAddress(expectedIP, expectedPort)
|
||||
addr.IP = expectedIP
|
||||
addr.Port = expectedPort
|
||||
addr.AddTo(m)
|
||||
m.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
@@ -124,8 +121,7 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) {
|
||||
if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, err = m.GetXORMappedAddress()
|
||||
if err == nil {
|
||||
if err = addr.GetFrom(m); err == nil {
|
||||
t.Fatal(err, "should not be nil")
|
||||
}
|
||||
}
|
||||
@@ -139,22 +135,26 @@ func TestMessage_AddXORMappedAddress(t *testing.T) {
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("213.141.156.236")
|
||||
expectedPort := 21254
|
||||
m.AddXORMappedAddress(expectedIP, expectedPort)
|
||||
addr := &XORMappedAddress{
|
||||
IP: net.ParseIP("213.141.156.236"),
|
||||
Port: expectedPort,
|
||||
}
|
||||
if err = addr.AddTo(m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||
if _, err = mRes.Write(m.Raw); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ip, port, err := m.GetXORMappedAddress()
|
||||
if err != nil {
|
||||
if err = addr.GetFrom(mRes); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !ip.Equal(expectedIP) {
|
||||
t.Error("bad ip", ip, "!=", expectedIP)
|
||||
if !addr.IP.Equal(expectedIP) {
|
||||
t.Errorf("%s (got) != %s (expected)", addr.IP, expectedIP)
|
||||
}
|
||||
if port != expectedPort {
|
||||
t.Error("bad port", port, "!=", expectedPort)
|
||||
if addr.Port != expectedPort {
|
||||
t.Error("bad Port", addr.Port, "!=", expectedPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,30 +167,47 @@ func TestMessage_AddXORMappedAddressV6(t *testing.T) {
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
|
||||
expectedPort := 21254
|
||||
m.AddXORMappedAddress(expectedIP, expectedPort)
|
||||
addr := &XORMappedAddress{
|
||||
IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"),
|
||||
Port: 21254,
|
||||
}
|
||||
addr.AddTo(m)
|
||||
m.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ip, port, err := m.GetXORMappedAddress()
|
||||
if err != nil {
|
||||
gotAddr := new(XORMappedAddress)
|
||||
if err = gotAddr.GetFrom(m); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !ip.Equal(expectedIP) {
|
||||
t.Error("bad ip", ip, "!=", expectedIP)
|
||||
if !gotAddr.IP.Equal(expectedIP) {
|
||||
t.Error("bad IP", gotAddr.IP, "!=", expectedIP)
|
||||
}
|
||||
if port != expectedPort {
|
||||
t.Error("bad port", port, "!=", expectedPort)
|
||||
if gotAddr.Port != expectedPort {
|
||||
t.Error("bad Port", gotAddr.Port, "!=", expectedPort)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessage_AddErrorCode(b *testing.B) {
|
||||
func BenchmarkErrorCode_AddTo(b *testing.B) {
|
||||
m := New()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.AddErrorCode(404, "Not found")
|
||||
CodeStaleNonce.AddTo(m)
|
||||
m.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkErrorCodeAttribute_AddTo(b *testing.B) {
|
||||
m := New()
|
||||
b.ReportAllocs()
|
||||
a := &ErrorCodeAttribute{
|
||||
Code: 404,
|
||||
Reason: []byte("not found!"),
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
a.AddTo(m)
|
||||
m.Reset()
|
||||
}
|
||||
}
|
||||
@@ -202,51 +219,27 @@ func TestMessage_AddErrorCode(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
expectedCode := 404
|
||||
expectedReason := "Not found"
|
||||
m.AddErrorCode(expectedCode, expectedReason)
|
||||
expectedCode := ErrorCode(428)
|
||||
expectedReason := "Stale Nonce"
|
||||
CodeStaleNonce.AddTo(m)
|
||||
m.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
code, reason, err := mRes.GetErrorCode()
|
||||
errCodeAttr := new(ErrorCodeAttribute)
|
||||
if err = errCodeAttr.GetFrom(mRes); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
code := errCodeAttr.Code
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if code != expectedCode {
|
||||
t.Error("bad code", code)
|
||||
}
|
||||
if string(reason) != expectedReason {
|
||||
t.Error("bad reason", string(reason))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_AddErrorCodeDefault(t *testing.T) {
|
||||
m := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
expectedCode := 500
|
||||
expectedReason := "Server Error"
|
||||
m.AddErrorCodeDefault(expectedCode)
|
||||
m.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
code, reason, err := mRes.GetErrorCode()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if code != expectedCode {
|
||||
t.Error("bad code", code)
|
||||
}
|
||||
if string(reason) != expectedReason {
|
||||
t.Error("bad reason", string(reason))
|
||||
if string(errCodeAttr.Reason) != expectedReason {
|
||||
t.Error("bad reason", string(errCodeAttr.Reason))
|
||||
}
|
||||
}
|
||||
|
11
errors.go
11
errors.go
@@ -15,6 +15,12 @@ type DecodeErr struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// IsInvalidCookie returns true if error means that magic cookie
|
||||
// value is invalid.
|
||||
func (e DecodeErr) IsInvalidCookie() bool {
|
||||
return e.Place == DecodeErrPlace{"message", "cookie"}
|
||||
}
|
||||
|
||||
// IsPlaceParent reports if error place parent is p.
|
||||
func (e DecodeErr) IsPlaceParent(p string) bool {
|
||||
return e.Place.Parent == p
|
||||
@@ -50,3 +56,8 @@ func newDecodeErr(parent, children, message string) *DecodeErr {
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(ar): rewrite errors to be more precise.
|
||||
func newAttrDecodeErr(children, message string) *DecodeErr {
|
||||
return newDecodeErr("attribute", children, message)
|
||||
}
|
||||
|
4
fuzz.go
4
fuzz.go
@@ -12,12 +12,12 @@ func FuzzMessage(data []byte) int {
|
||||
// fuzzer dont know about cookies
|
||||
binary.BigEndian.PutUint32(data[4:8], magicCookie)
|
||||
// trying to read data as message
|
||||
if _, err := m.ReadBytes(data); err != nil {
|
||||
if _, err := m.Write(data); err != nil {
|
||||
return 0
|
||||
}
|
||||
m.WriteHeader()
|
||||
m2 := New()
|
||||
if _, err := m2.ReadBytes(m2.Raw); err != nil {
|
||||
if _, err := m2.Write(m2.Raw); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if m2.TransactionID != m.TransactionID {
|
||||
|
491
message.go
Normal file
491
message.go
Normal file
@@ -0,0 +1,491 @@
|
||||
// Package stun implements Session Traversal Utilities for NAT (STUN) RFC 5389.
|
||||
//
|
||||
// Definitions
|
||||
//
|
||||
// STUN Agent: A STUN agent is an entity that implements the STUN
|
||||
// protocol. The entity can be either a STUN client or a STUN
|
||||
// server.
|
||||
//
|
||||
// STUN Client: A STUN client is an entity that sends STUN requests and
|
||||
// receives STUN responses. A STUN client can also send indications.
|
||||
// In this specification, the terms STUN client and client are
|
||||
// synonymous.
|
||||
//
|
||||
// STUN Server: A STUN server is an entity that receives STUN requests
|
||||
// and sends STUN responses. A STUN server can also send
|
||||
// indications. In this specification, the terms STUN server and
|
||||
// server are synonymous.
|
||||
//
|
||||
// Transport Address: The combination of an IP address and Port number
|
||||
// (such as a UDP or TCP Port number).
|
||||
package stun
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
// magicCookie is fixed value that aids in distinguishing STUN packets
|
||||
// from packets of other protocols when STUN is multiplexed with those
|
||||
// other protocols on the same Port.
|
||||
//
|
||||
// The magic cookie field MUST contain the fixed value 0x2112A442 in
|
||||
// network byte order.
|
||||
//
|
||||
// Defined in "STUN Message Structure", section 6.
|
||||
magicCookie = 0x2112A442
|
||||
attributeHeaderSize = 4
|
||||
messageHeaderSize = 20
|
||||
transactionIDSize = 12 // 96 bit
|
||||
)
|
||||
|
||||
// NewTransactionID returns new random transaction ID using crypto/rand
|
||||
// as source.
|
||||
func NewTransactionID() (b [transactionIDSize]byte) {
|
||||
_, err := rand.Read(b[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// IsMessage returns true if b looks like STUN message.
|
||||
// Useful for multiplexing. IsMessage does not guarantee
|
||||
// that decoding will be successful.
|
||||
func IsMessage(b []byte) bool {
|
||||
return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie
|
||||
}
|
||||
|
||||
// New returns *Message with pre-allocated Raw.
|
||||
func New() *Message {
|
||||
const defaultRawCapacity = 120
|
||||
return &Message{
|
||||
Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
|
||||
}
|
||||
}
|
||||
|
||||
// Message represents a single STUN packet. It uses aggressive internal
|
||||
// buffering to enable zero-allocation encoding and decoding,
|
||||
// so there are some usage constraints:
|
||||
//
|
||||
// * Message and its fields is valid only until AcquireMessage call.
|
||||
type Message struct {
|
||||
Type MessageType
|
||||
Length uint32 // len(Raw) not including header
|
||||
TransactionID [transactionIDSize]byte
|
||||
Attributes Attributes
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
// NewTransactionID sets m.TransactionID to random value from crypto/rand
|
||||
// and returns error if any.
|
||||
func (m *Message) NewTransactionID() error {
|
||||
_, err := rand.Read(m.TransactionID[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (m Message) String() string {
|
||||
return fmt.Sprintf("%s l=%d attrs=%d id=%s",
|
||||
m.Type,
|
||||
m.Length,
|
||||
len(m.Attributes),
|
||||
base64.StdEncoding.EncodeToString(m.TransactionID[:]),
|
||||
)
|
||||
}
|
||||
|
||||
// Reset resets Message, attributes and underlying buffer length.
|
||||
func (m *Message) Reset() {
|
||||
m.Raw = m.Raw[:0]
|
||||
m.Length = 0
|
||||
m.Attributes = m.Attributes[:0]
|
||||
}
|
||||
|
||||
// grow ensures that internal buffer will fit v more bytes and
|
||||
// increases it capacity if necessary.
|
||||
func (m *Message) grow(v int) {
|
||||
// Not performing any optimizations here
|
||||
// (e.g. preallocate len(buf) * 2 to reduce allocations)
|
||||
// because they are already done by []byte implementation.
|
||||
n := len(m.Raw) + v
|
||||
for cap(m.Raw) < n {
|
||||
m.Raw = append(m.Raw, 0)
|
||||
}
|
||||
m.Raw = m.Raw[:n]
|
||||
}
|
||||
|
||||
// Add appends new attribute to message. Not goroutine-safe.
|
||||
//
|
||||
// Value of attribute is copied to internal buffer so
|
||||
// it is safe to reuse v.
|
||||
func (m *Message) Add(t AttrType, v []byte) {
|
||||
// Allocating buffer for TLV (type-length-value).
|
||||
// T = t, L = len(v), V = v.
|
||||
// m.Raw will look like:
|
||||
// [0:20] <- message header
|
||||
// [20:20+m.Length] <- existing message attributes
|
||||
// [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
|
||||
// [first:last] <- same as previous
|
||||
// [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
|
||||
// T L V
|
||||
allocSize := attributeHeaderSize + len(v) // len(TLV) = len(TL) + len(V)
|
||||
first := messageHeaderSize + int(m.Length) // first byte number
|
||||
last := first + allocSize // last byte number
|
||||
m.grow(last) // growing cap(Raw) to fit TLV
|
||||
m.Raw = m.Raw[:last] // now len(Raw) = last
|
||||
m.Length += uint32(allocSize) // rendering length change
|
||||
|
||||
// Sub-slicing internal buffer to simplify encoding.
|
||||
buf := m.Raw[first:last] // slice for TLV
|
||||
value := buf[attributeHeaderSize:] // slice for V
|
||||
attr := RawAttribute{
|
||||
Type: t, // T
|
||||
Length: uint16(len(v)), // L
|
||||
Value: value, // V
|
||||
}
|
||||
|
||||
// Encoding attribute TLV to allocated buffer.
|
||||
bin.PutUint16(buf[0:2], attr.Type.Value()) // T
|
||||
bin.PutUint16(buf[2:4], attr.Length) // L
|
||||
copy(value, v) // V
|
||||
|
||||
// Checking that attribute value needs padding.
|
||||
if attr.Length%padding != 0 {
|
||||
// Performing padding.
|
||||
bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
|
||||
last += bytesToAdd
|
||||
m.grow(last)
|
||||
// setting all padding bytes to zero
|
||||
// to prevent data leak from previous
|
||||
// data in next bytesToAdd bytes
|
||||
buf = m.Raw[last-bytesToAdd : last]
|
||||
for i := range buf {
|
||||
buf[i] = 0
|
||||
}
|
||||
m.Raw = m.Raw[:last] // increasing buffer length
|
||||
m.Length += uint32(bytesToAdd) // rendering length change
|
||||
}
|
||||
m.Attributes = append(m.Attributes, attr)
|
||||
}
|
||||
|
||||
// Equal returns true if Message b equals to m.
|
||||
// Ignores m.Raw.
|
||||
func (m *Message) Equal(b *Message) bool {
|
||||
if m.Type != b.Type {
|
||||
return false
|
||||
}
|
||||
if m.TransactionID != b.TransactionID {
|
||||
return false
|
||||
}
|
||||
if m.Length != b.Length {
|
||||
return false
|
||||
}
|
||||
for _, a := range m.Attributes {
|
||||
aB, ok := b.Attributes.Get(a.Type)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !aB.Equal(a) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// WriteLength writes m.Length to m.Raw. Call is valid only if len(m.Raw) >= 4.
|
||||
func (m *Message) WriteLength() {
|
||||
_ = m.Raw[4] // early bounds check to guarantee safety of writes below
|
||||
bin.PutUint16(m.Raw[2:4], uint16(m.Length))
|
||||
}
|
||||
|
||||
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
|
||||
func (m *Message) WriteHeader() {
|
||||
if len(m.Raw) < messageHeaderSize {
|
||||
// Making WriteHeader call valid even when m.Raw
|
||||
// is nil or len(m.Raw) is less than needed for header.
|
||||
m.grow(messageHeaderSize)
|
||||
}
|
||||
_ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below
|
||||
|
||||
bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type
|
||||
bin.PutUint16(m.Raw[2:4], uint16(len(m.Raw)-messageHeaderSize)) // size of payload
|
||||
bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie
|
||||
copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
|
||||
}
|
||||
|
||||
// WriteAttributes encodes all m.Attributes to m.
|
||||
func (m *Message) WriteAttributes() {
|
||||
for _, a := range m.Attributes {
|
||||
m.Add(a.Type, a.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode resets m.Raw and calls WriteHeader and WriteAttributes.
|
||||
func (m *Message) Encode() {
|
||||
m.Raw = m.Raw[:0]
|
||||
m.WriteHeader()
|
||||
m.WriteAttributes()
|
||||
}
|
||||
|
||||
// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
|
||||
// call result.
|
||||
func (m *Message) WriteTo(w io.Writer) (int64, error) {
|
||||
n, err := w.Write(m.Raw)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// Append appends m.Raw to v. Useful to call after encoding message.
|
||||
func (m *Message) Append(v []byte) []byte {
|
||||
return append(v, m.Raw...)
|
||||
}
|
||||
|
||||
// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
|
||||
// Decodes it and return error if any. If m.Raw is too small, will return
|
||||
// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
|
||||
//
|
||||
// Can return *DecodeErr while decoding too.
|
||||
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
|
||||
tBuf := m.Raw[:cap(m.Raw)]
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
if n, err = r.Read(tBuf); err != nil {
|
||||
return int64(n), err
|
||||
}
|
||||
m.Raw = tBuf[:n]
|
||||
return int64(n), m.Decode()
|
||||
}
|
||||
|
||||
const (
|
||||
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
|
||||
// m.Raw to read header.
|
||||
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header"
|
||||
)
|
||||
|
||||
// Decode decodes m.Raw into m.
|
||||
func (m *Message) Decode() error {
|
||||
// decoding message header
|
||||
buf := m.Raw
|
||||
if len(buf) < messageHeaderSize {
|
||||
return ErrUnexpectedHeaderEOF
|
||||
}
|
||||
var (
|
||||
t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes
|
||||
size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes
|
||||
cookie = binary.BigEndian.Uint32(buf[4:8])
|
||||
fullSize = messageHeaderSize + size
|
||||
)
|
||||
if cookie != magicCookie {
|
||||
msg := fmt.Sprintf(
|
||||
"%x is invalid magic cookie (should be %x)",
|
||||
cookie, magicCookie,
|
||||
)
|
||||
return newDecodeErr("message", "cookie", msg)
|
||||
}
|
||||
if len(buf) < fullSize {
|
||||
msg := fmt.Sprintf(
|
||||
"buffer length %d is less than %d (expected message size)",
|
||||
len(buf), fullSize,
|
||||
)
|
||||
return newAttrDecodeErr("message", msg)
|
||||
}
|
||||
// saving header data
|
||||
m.Type.ReadValue(t)
|
||||
m.Length = uint32(size)
|
||||
copy(m.TransactionID[:], buf[8:messageHeaderSize])
|
||||
|
||||
var (
|
||||
offset = 0
|
||||
b = buf[messageHeaderSize:fullSize]
|
||||
)
|
||||
for offset < size {
|
||||
// checking that we have enough bytes to read header
|
||||
if len(b) < attributeHeaderSize {
|
||||
msg := fmt.Sprintf(
|
||||
"buffer length %d is less than %d (expected header size)",
|
||||
len(b), attributeHeaderSize,
|
||||
)
|
||||
return newAttrDecodeErr("header", msg)
|
||||
}
|
||||
var (
|
||||
a = RawAttribute{
|
||||
Type: AttrType(bin.Uint16(b[0:2])), // first 2 bytes
|
||||
Length: bin.Uint16(b[2:4]), // second 2 bytes
|
||||
}
|
||||
aL = int(a.Length) // attribute length
|
||||
aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
|
||||
)
|
||||
b = b[attributeHeaderSize:] // slicing again to simplify value read
|
||||
offset += attributeHeaderSize
|
||||
if len(b) < aBuffL { // checking size
|
||||
msg := fmt.Sprintf(
|
||||
"buffer length %d is less than %d (expected value size)",
|
||||
len(b), aBuffL,
|
||||
)
|
||||
return newAttrDecodeErr("value", msg)
|
||||
}
|
||||
a.Value = b[:aL]
|
||||
offset += aBuffL
|
||||
b = b[aBuffL:]
|
||||
|
||||
m.Attributes = append(m.Attributes, a)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write decodes message and return error if any.
|
||||
//
|
||||
// Any error is unrecoverable, but message could be partially decoded.
|
||||
func (m *Message) Write(tBuf []byte) (int, error) {
|
||||
m.Raw = append(m.Raw[:0], tBuf...)
|
||||
return len(tBuf), m.Decode()
|
||||
}
|
||||
|
||||
// MaxPacketSize is maximum size of UDP packet that is processable in
|
||||
// this package for STUN message.
|
||||
const MaxPacketSize = 2048
|
||||
|
||||
// MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
|
||||
type MessageClass byte
|
||||
|
||||
// Possible values for message class in STUN Message Type.
|
||||
const (
|
||||
ClassRequest MessageClass = 0x00 // 0b00
|
||||
ClassIndication MessageClass = 0x01 // 0b01
|
||||
ClassSuccessResponse MessageClass = 0x02 // 0b10
|
||||
ClassErrorResponse MessageClass = 0x03 // 0b11
|
||||
)
|
||||
|
||||
func (c MessageClass) String() string {
|
||||
switch c {
|
||||
case ClassRequest:
|
||||
return "request"
|
||||
case ClassIndication:
|
||||
return "indication"
|
||||
case ClassSuccessResponse:
|
||||
return "success response"
|
||||
case ClassErrorResponse:
|
||||
return "error response"
|
||||
default:
|
||||
panic("unknown message class")
|
||||
}
|
||||
}
|
||||
|
||||
// Method is uint16 representation of 12-bit STUN method.
|
||||
type Method uint16
|
||||
|
||||
// Possible methods for STUN Message.
|
||||
const (
|
||||
MethodBinding Method = 0x001
|
||||
MethodAllocate Method = 0x003
|
||||
MethodRefresh Method = 0x004
|
||||
MethodSend Method = 0x006
|
||||
MethodData Method = 0x007
|
||||
MethodCreatePermission Method = 0x008
|
||||
MethodChannelBind Method = 0x009
|
||||
)
|
||||
|
||||
func (m Method) String() string {
|
||||
switch m {
|
||||
case MethodBinding:
|
||||
return "binding"
|
||||
case MethodAllocate:
|
||||
return "allocate"
|
||||
case MethodRefresh:
|
||||
return "refresh"
|
||||
case MethodSend:
|
||||
return "send"
|
||||
case MethodData:
|
||||
return "data"
|
||||
case MethodCreatePermission:
|
||||
return "create permission"
|
||||
case MethodChannelBind:
|
||||
return "channel bind"
|
||||
default:
|
||||
return fmt.Sprintf("0x%s", strconv.FormatUint(uint64(m), 16))
|
||||
}
|
||||
}
|
||||
|
||||
// MessageType is STUN Message Type Field.
|
||||
type MessageType struct {
|
||||
Class MessageClass
|
||||
Method Method
|
||||
}
|
||||
|
||||
const (
|
||||
methodABits = 0xf // 0b0000000000001111
|
||||
methodBBits = 0x70 // 0b0000000001110000
|
||||
methodDBits = 0xf80 // 0b0000111110000000
|
||||
|
||||
methodBShift = 1
|
||||
methodDShift = 2
|
||||
|
||||
firstBit = 0x1
|
||||
secondBit = 0x2
|
||||
|
||||
c0Bit = firstBit
|
||||
c1Bit = secondBit
|
||||
|
||||
classC0Shift = 4
|
||||
classC1Shift = 7
|
||||
)
|
||||
|
||||
// Value returns bit representation of messageType.
|
||||
func (t MessageType) Value() uint16 {
|
||||
// 0 1
|
||||
// 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
|
||||
// |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
|
||||
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// Figure 3: Format of STUN Message Type Field
|
||||
|
||||
// Warning: Abandon all hope ye who enter here.
|
||||
// Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
|
||||
m := uint16(t.Method)
|
||||
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
|
||||
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
|
||||
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
|
||||
|
||||
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
|
||||
m = a + (b << methodBShift) + (d << methodDShift)
|
||||
|
||||
// C0 is zero bit of C, C1 is fist bit.
|
||||
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
|
||||
// Ct = C0 << 4 + C1 << 8.
|
||||
// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
|
||||
// We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
|
||||
// (see figure 3).
|
||||
c := uint16(t.Class)
|
||||
c0 := (c & c0Bit) << classC0Shift
|
||||
c1 := (c & c1Bit) << classC1Shift
|
||||
class := c0 + c1
|
||||
|
||||
return m + class
|
||||
}
|
||||
|
||||
// ReadValue decodes uint16 into MessageType.
|
||||
func (t *MessageType) ReadValue(v uint16) {
|
||||
// Decoding class.
|
||||
// We are taking first bit from v >> 4 and second from v >> 7.
|
||||
c0 := (v >> classC0Shift) & c0Bit
|
||||
c1 := (v >> classC1Shift) & c1Bit
|
||||
class := c0 + c1
|
||||
t.Class = MessageClass(class)
|
||||
|
||||
// Decoding method.
|
||||
a := v & methodABits // A(M0-M3)
|
||||
b := (v >> methodBShift) & methodBBits // B(M4-M6)
|
||||
d := (v >> methodDShift) & methodDBits // D(M7-M11)
|
||||
m := a + b + d
|
||||
t.Method = Method(m)
|
||||
}
|
||||
|
||||
func (t MessageType) String() string {
|
||||
return fmt.Sprintf("%s %s", t.Method, t.Class)
|
||||
}
|
13
padding.go
Normal file
13
padding.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package stun
|
||||
|
||||
const (
|
||||
padding = 4
|
||||
)
|
||||
|
||||
func nearestPaddedValueLength(l int) int {
|
||||
n := padding * (l / padding)
|
||||
if n < l {
|
||||
n += padding
|
||||
}
|
||||
return n
|
||||
}
|
@@ -6,7 +6,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ernado/stun"
|
||||
@@ -17,13 +17,14 @@ var (
|
||||
fmt.Sprintf("127.0.0.1:%d", stun.DefaultPort),
|
||||
"addr to attack",
|
||||
)
|
||||
readWorkers = flag.Int("read-workers", 1, "concurrent read workers")
|
||||
writeWorkers = flag.Int("write-workers", 1, "concurrent write workers")
|
||||
|
||||
count int64
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
runtime.GOMAXPROCS(2)
|
||||
log.SetFlags(log.Lshortfile)
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
@@ -43,8 +44,9 @@ func main() {
|
||||
},
|
||||
TransactionID: stun.NewTransactionID(),
|
||||
}
|
||||
m.AddRaw(stun.AttrSoftware, []byte("stun benchmark"))
|
||||
m.Add(stun.AttrSoftware, []byte("stun benchmark"))
|
||||
m.Encode()
|
||||
for i := 0; i < *readWorkers; i++ {
|
||||
go func() {
|
||||
mRec := stun.New()
|
||||
mRec.Raw = make([]byte, 1024)
|
||||
@@ -58,15 +60,26 @@ func main() {
|
||||
// if err := mRec.Decode(); err != nil {
|
||||
// log.Fatalln("Decode:", err)
|
||||
// }
|
||||
count++
|
||||
if count%10000 == 0 {
|
||||
fmt.Printf("%d\n", count)
|
||||
atomic.AddInt64(&count, 1)
|
||||
if c := atomic.LoadInt64(&count); c%10000 == 0 {
|
||||
fmt.Printf("%d\n", c)
|
||||
elapsed := time.Since(start)
|
||||
fmt.Println(float64(count)/elapsed.Seconds(), "per second")
|
||||
fmt.Println(float64(c)/elapsed.Seconds(), "per second")
|
||||
}
|
||||
// mRec.Reset()
|
||||
}
|
||||
}()
|
||||
}
|
||||
for i := 1; i < *writeWorkers; i++ {
|
||||
go func() {
|
||||
for {
|
||||
_, err := c.Write(m.Raw)
|
||||
if err != nil {
|
||||
log.Fatalln("write:", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
for {
|
||||
_, err := c.Write(m.Raw)
|
||||
if err != nil {
|
||||
|
@@ -33,7 +33,7 @@ func TestClient_Do(t *testing.T) {
|
||||
m := stun.New()
|
||||
m.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
|
||||
m.TransactionID = stun.NewTransactionID()
|
||||
m.Add(stun.NewSoftware("cydev/stun alpha"))
|
||||
stun.NewSoftware("cydev/stun alpha").AddTo(m)
|
||||
m.WriteHeader()
|
||||
request := Request{
|
||||
Target: "stun.l.google.com:19302",
|
||||
@@ -43,11 +43,11 @@ func TestClient_Do(t *testing.T) {
|
||||
if r.Message.TransactionID != m.TransactionID {
|
||||
t.Error("transaction id messmatch")
|
||||
}
|
||||
ip, port, err := r.Message.GetXORMappedAddress()
|
||||
if err != nil {
|
||||
addr := new(stun.XORMappedAddress)
|
||||
if err := addr.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
log.Println("got", ip, port)
|
||||
log.Println("got", addr)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
|
@@ -190,7 +190,7 @@ func discover(c *cli.Context) error {
|
||||
Class: stun.ClassRequest,
|
||||
},
|
||||
}
|
||||
m.AddRaw(stun.AttrSoftware, software.Raw)
|
||||
m.Add(stun.AttrSoftware, software.Raw)
|
||||
m.WriteHeader()
|
||||
|
||||
request := Request{
|
||||
@@ -200,15 +200,13 @@ func discover(c *cli.Context) error {
|
||||
|
||||
return DefaultClient.Do(request, func(r Response) error {
|
||||
var (
|
||||
ip net.IP
|
||||
port int
|
||||
err error
|
||||
)
|
||||
ip, port, err = r.Message.GetXORMappedAddress()
|
||||
if err != nil {
|
||||
addr := new(stun.XORMappedAddress)
|
||||
if err = addr.GetFrom(r.Message); err != nil {
|
||||
return errors.Wrap(err, "failed to get ip")
|
||||
}
|
||||
fmt.Println(ip, port)
|
||||
fmt.Println(addr)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
@@ -23,7 +23,7 @@ func main() {
|
||||
log.Fatalln("Unable to decode bas64 value:", err)
|
||||
}
|
||||
m := stun.New()
|
||||
if _, err = m.ReadBytes(data); err != nil {
|
||||
if _, err = m.Write(data); err != nil {
|
||||
log.Fatalln("Unable to decode message:", err)
|
||||
}
|
||||
fmt.Println(m)
|
||||
|
540
stun.go
540
stun.go
@@ -16,542 +16,14 @@
|
||||
// indications. In this specification, the terms STUN server and
|
||||
// server are synonymous.
|
||||
//
|
||||
// Transport Address: The combination of an IP address and port number
|
||||
// (such as a UDP or TCP port number).
|
||||
// Transport Address: The combination of an IP address and Port number
|
||||
// (such as a UDP or TCP Port number).
|
||||
package stun
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
import "encoding/binary"
|
||||
|
||||
var (
|
||||
// bin is shorthand to binary.BigEndian.
|
||||
bin = binary.BigEndian
|
||||
)
|
||||
// bin is shorthand to binary.BigEndian.
|
||||
var bin = binary.BigEndian
|
||||
|
||||
// DefaultPort is IANA assigned port for "stun" protocol.
|
||||
// DefaultPort is IANA assigned Port for "stun" protocol.
|
||||
const DefaultPort = 3478
|
||||
|
||||
const (
|
||||
// magicCookie is fixed value that aids in distinguishing STUN packets
|
||||
// from packets of other protocols when STUN is multiplexed with those
|
||||
// other protocols on the same port.
|
||||
//
|
||||
// The magic cookie field MUST contain the fixed value 0x2112A442 in
|
||||
// network byte order.
|
||||
//
|
||||
// Defined in "STUN Message Structure", section 6.
|
||||
magicCookie = 0x2112A442
|
||||
)
|
||||
|
||||
const transactionIDSize = 12 // 96 bit
|
||||
|
||||
// Message represents a single STUN packet. It uses aggressive internal
|
||||
// buffering to enable zero-allocation encoding and decoding,
|
||||
// so there are some usage constraints:
|
||||
//
|
||||
// * Message and its fields is valid only until AcquireMessage call.
|
||||
type Message struct {
|
||||
Type MessageType
|
||||
Length uint32 // len(Raw) not including header
|
||||
TransactionID [transactionIDSize]byte
|
||||
Attributes Attributes
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
// CopyTo copies all m to c.
|
||||
func (m Message) CopyTo(c *Message) {
|
||||
c.Type = m.Type
|
||||
c.Length = m.Length
|
||||
copy(c.TransactionID[:], m.TransactionID[:])
|
||||
buf := m.Raw[:int(m.Length)+messageHeaderSize]
|
||||
c.Raw = c.Raw[:0]
|
||||
c.Raw = append(c.Raw, buf...)
|
||||
buf = c.Raw[messageHeaderSize:]
|
||||
for _, a := range m.Attributes {
|
||||
buf = buf[attributeHeaderSize:]
|
||||
c.Attributes = append(c.Attributes, RawAttribute{
|
||||
Length: a.Length,
|
||||
Type: a.Type,
|
||||
Value: buf[:int(a.Length)],
|
||||
})
|
||||
buf = buf[int(a.Length):]
|
||||
}
|
||||
}
|
||||
|
||||
func (m Message) String() string {
|
||||
return fmt.Sprintf("%s (l=%d,%d/%d) attr[%d] id[%s]",
|
||||
m.Type,
|
||||
m.Length,
|
||||
len(m.Raw),
|
||||
cap(m.Raw),
|
||||
len(m.Attributes),
|
||||
base64.StdEncoding.EncodeToString(m.TransactionID[:]),
|
||||
)
|
||||
}
|
||||
|
||||
// NewTransactionID returns new random transaction ID using crypto/rand
|
||||
// as source.
|
||||
func NewTransactionID() (b [transactionIDSize]byte) {
|
||||
_, err := rand.Read(b[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// defaults for pool.
|
||||
const (
|
||||
defaultMessageBufferCapacity = 120
|
||||
)
|
||||
|
||||
// New returns *Message with allocated Raw.
|
||||
func New() *Message {
|
||||
return &Message{
|
||||
Raw: make([]byte, messageHeaderSize, defaultMessageBufferCapacity),
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets Message length, attributes and underlying buffer, as well as
|
||||
// setting readOnly flag to false.
|
||||
func (m *Message) Reset() {
|
||||
m.Raw = m.Raw[:0]
|
||||
m.Length = 0
|
||||
m.Attributes = m.Attributes[:0]
|
||||
}
|
||||
|
||||
// grow ensures that internal buffer will fit v more bytes and
|
||||
// increases it capacity if necessary.
|
||||
func (m *Message) grow(v int) {
|
||||
// Not performing any optimizations here
|
||||
// (e.g. preallocate len(buf) * 2 to reduce allocations)
|
||||
// because they are already done by []byte implementation.
|
||||
n := len(m.Raw) + v
|
||||
for cap(m.Raw) < n {
|
||||
m.Raw = append(m.Raw, 0)
|
||||
}
|
||||
m.Raw = m.Raw[:n]
|
||||
}
|
||||
|
||||
// Add adds AttrEncoder to message, calling Encode method.
|
||||
func (m *Message) Add(a AttrEncoder) error {
|
||||
var (
|
||||
err error
|
||||
t AttrType
|
||||
)
|
||||
initial := len(m.Raw)
|
||||
for i := 0; i < attributeHeaderSize; i++ {
|
||||
m.Raw = append(m.Raw, 0)
|
||||
}
|
||||
start := len(m.Raw)
|
||||
t, m.Raw, err = a.Encode(m.Raw, m)
|
||||
if err != nil {
|
||||
m.Raw = m.Raw[:initial]
|
||||
return err
|
||||
}
|
||||
m.AddRaw(t, m.Raw[start:])
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRaw appends new attribute to message. Not goroutine-safe.
|
||||
//
|
||||
// Value of attribute is copied to internal buffer so
|
||||
// it is safe to reuse v.
|
||||
func (m *Message) AddRaw(t AttrType, v []byte) {
|
||||
// CPU: suboptimal;
|
||||
// allocating memory for TLV (type-length-value), where
|
||||
// type-length is attribute header.
|
||||
// m.buf.B[0:20] is reserved by header
|
||||
// internal buffer will look like:
|
||||
// [0:20] <- message header
|
||||
// [20:20+m.Length] <- added message attributes
|
||||
// [20+m.Length:20+m.Length+len(v) + 4] <- allocated slice for new TLV
|
||||
// [first:last] <- same as previous
|
||||
// [0 1|2 3|4 4 + len(v)]
|
||||
// T L V
|
||||
allocSize := attributeHeaderSize + len(v) // len(TLV)
|
||||
first := messageHeaderSize + int(m.Length) // first byte number
|
||||
last := first + allocSize // last byte number
|
||||
m.grow(last) // growing cap(b) to fit TLV
|
||||
m.Raw = m.Raw[:last] // now len(b) = last
|
||||
m.Length += uint32(allocSize) // rendering length change
|
||||
// subslicing internal buffer to simplify encoding
|
||||
buf := m.Raw[first:last] // slice for TLV
|
||||
value := buf[attributeHeaderSize:] // slice for value
|
||||
attr := RawAttribute{
|
||||
Type: t,
|
||||
Value: value,
|
||||
Length: uint16(len(v)),
|
||||
}
|
||||
// encoding attribute TLV to internal buffer
|
||||
bin.PutUint16(buf[0:2], attr.Type.Value()) // type
|
||||
bin.PutUint16(buf[2:4], attr.Length) // length
|
||||
copy(value, v) // value
|
||||
if attr.Length%padding != 0 {
|
||||
// performing padding
|
||||
bytesToAdd := nearestLength(len(v)) - len(v)
|
||||
last += bytesToAdd
|
||||
m.grow(last)
|
||||
// setting all padding bytes to zero
|
||||
// to prevent data leak from previous
|
||||
// data in next bytesToAdd bytes
|
||||
buf = m.Raw[last-bytesToAdd : last]
|
||||
for i := range buf {
|
||||
buf[i] = 0
|
||||
}
|
||||
m.Raw = m.Raw[:last] // increasing buffer length
|
||||
m.Length += uint32(bytesToAdd) // rendering length change
|
||||
}
|
||||
m.Attributes = append(m.Attributes, attr)
|
||||
}
|
||||
|
||||
const (
|
||||
padding = 4
|
||||
)
|
||||
|
||||
func nearestLength(l int) int {
|
||||
n := padding * (l / padding)
|
||||
if n < l {
|
||||
n += padding
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Equal returns true if Message b equals to m.
|
||||
func (m *Message) Equal(b *Message) bool {
|
||||
if m.Type != b.Type {
|
||||
return false
|
||||
}
|
||||
if m.TransactionID != b.TransactionID {
|
||||
return false
|
||||
}
|
||||
if m.Length != b.Length {
|
||||
return false
|
||||
}
|
||||
for _, a := range m.Attributes {
|
||||
aB, ok := b.Attributes.Get(a.Type)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !aB.Equal(a) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// WriteLength writes m.Length to m.Raw.
|
||||
func (m *Message) WriteLength() { bin.PutUint16(m.Raw[2:4], uint16(m.Length)) }
|
||||
|
||||
// WriteHeader writes header to underlying buffer. Not goroutine-safe.
|
||||
func (m *Message) WriteHeader() {
|
||||
// encoding header
|
||||
if len(m.Raw) < messageHeaderSize {
|
||||
m.grow(messageHeaderSize)
|
||||
}
|
||||
bin.PutUint16(m.Raw[0:2], m.Type.Value())
|
||||
bin.PutUint32(m.Raw[4:8], magicCookie)
|
||||
copy(m.Raw[8:messageHeaderSize], m.TransactionID[:])
|
||||
// attributes are already encoded
|
||||
// writing length as size, in bytes, not including the 20-byte STUN header.
|
||||
bin.PutUint16(m.Raw[2:4], uint16(len(m.Raw)-messageHeaderSize))
|
||||
}
|
||||
|
||||
// WriteAttributes encodes all m.Attributes to m.
|
||||
func (m *Message) WriteAttributes() {
|
||||
for _, a := range m.Attributes {
|
||||
m.AddRaw(a.Type, a.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode writes m into Raw.
|
||||
func (m *Message) Encode() {
|
||||
m.Raw = m.Raw[:0]
|
||||
m.WriteHeader()
|
||||
m.WriteAttributes()
|
||||
}
|
||||
|
||||
// WriteTo implements WriterTo.
|
||||
func (m *Message) WriteTo(w io.Writer) (int64, error) {
|
||||
n, err := w.Write(m.Raw)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// WriteToConn writes a packet with message to addr, using c.
|
||||
// Deprecated; non-idiomatic.
|
||||
func (m *Message) WriteToConn(c net.PacketConn, addr net.Addr) (n int, err error) {
|
||||
return c.WriteTo(m.Raw, addr)
|
||||
}
|
||||
|
||||
// ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
|
||||
// Decodes it and return error if any. If m.Raw is too small, will return
|
||||
// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
|
||||
//
|
||||
// Can return *DecodeErr while decoding.
|
||||
func (m *Message) ReadFrom(r io.Reader) (int64, error) {
|
||||
tBuf := m.Raw[:cap(m.Raw)]
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
if n, err = r.Read(tBuf); err != nil {
|
||||
return int64(n), err
|
||||
}
|
||||
m.Raw = tBuf[:n]
|
||||
return int64(n), m.Decode()
|
||||
}
|
||||
|
||||
func newAttrDecodeErr(children, message string) *DecodeErr {
|
||||
return newDecodeErr("attribute", children, message)
|
||||
}
|
||||
|
||||
// IsMessage returns true if b looks like STUN message.
|
||||
// Useful for multiplexing. IsMessage does not guarantee
|
||||
// that decoding will be successful.
|
||||
func IsMessage(b []byte) bool {
|
||||
// b should be at least messageHeaderSize bytes and
|
||||
// contain correct magicCookie.
|
||||
return len(b) >= messageHeaderSize &&
|
||||
binary.BigEndian.Uint32(b[4:8]) == magicCookie
|
||||
}
|
||||
|
||||
const (
|
||||
// ErrUnexpectedHeaderEOF means that there were not enough bytes in
|
||||
// m.Raw to read header.
|
||||
ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header"
|
||||
)
|
||||
|
||||
// Decode decodes m.Raw into m.
|
||||
func (m *Message) Decode() error {
|
||||
// decoding message header
|
||||
buf := m.Raw
|
||||
if len(buf) < messageHeaderSize {
|
||||
return ErrUnexpectedHeaderEOF
|
||||
}
|
||||
var (
|
||||
t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes
|
||||
size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes
|
||||
cookie = binary.BigEndian.Uint32(buf[4:8])
|
||||
fullSize = messageHeaderSize + size
|
||||
)
|
||||
if cookie != magicCookie {
|
||||
msg := fmt.Sprintf(
|
||||
"%x is invalid magic cookie (should be %x)",
|
||||
cookie, magicCookie,
|
||||
)
|
||||
return newDecodeErr("message", "cookie", msg)
|
||||
}
|
||||
if len(buf) < fullSize {
|
||||
msg := fmt.Sprintf(
|
||||
"buffer length %d is less than %d (expected message size)",
|
||||
len(buf), fullSize,
|
||||
)
|
||||
return newAttrDecodeErr("message", msg)
|
||||
}
|
||||
// saving header data
|
||||
m.Type.ReadValue(t)
|
||||
m.Length = uint32(size)
|
||||
copy(m.TransactionID[:], buf[8:messageHeaderSize])
|
||||
|
||||
var (
|
||||
offset = 0
|
||||
b = buf[messageHeaderSize:fullSize]
|
||||
)
|
||||
for offset < size {
|
||||
// checking that we have enough bytes to read header
|
||||
if len(b) < attributeHeaderSize {
|
||||
msg := fmt.Sprintf(
|
||||
"buffer length %d is less than %d (expected header size)",
|
||||
len(b), attributeHeaderSize,
|
||||
)
|
||||
return newAttrDecodeErr("header", msg)
|
||||
}
|
||||
var (
|
||||
a = RawAttribute{
|
||||
Type: AttrType(bin.Uint16(b[0:2])), // first 2 bytes
|
||||
Length: bin.Uint16(b[2:4]), // second 2 bytes
|
||||
}
|
||||
aL = int(a.Length) // attribute length
|
||||
aBuffL = nearestLength(aL) // expected buffer length (with padding)
|
||||
)
|
||||
b = b[attributeHeaderSize:] // slicing again to simplify value read
|
||||
offset += attributeHeaderSize
|
||||
if len(b) < aBuffL { // checking size
|
||||
msg := fmt.Sprintf(
|
||||
"buffer length %d is less than %d (expected value size)",
|
||||
len(b), aBuffL,
|
||||
)
|
||||
return newAttrDecodeErr("value", msg)
|
||||
}
|
||||
a.Value = b[:aL]
|
||||
offset += aBuffL
|
||||
b = b[aBuffL:]
|
||||
|
||||
m.Attributes = append(m.Attributes, a)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadBytes decodes message and return error if any.
|
||||
//
|
||||
// Any error is unrecoverable, but message could be partially decoded.
|
||||
// Deprecated: use m.Decode.
|
||||
func (m *Message) ReadBytes(tBuf []byte) (int, error) {
|
||||
m.Raw = append(m.Raw[:0], tBuf...)
|
||||
return len(tBuf), m.Decode()
|
||||
}
|
||||
|
||||
const (
|
||||
attributeHeaderSize = 4
|
||||
messageHeaderSize = 20
|
||||
)
|
||||
|
||||
// MaxPacketSize is maximum size of UDP packet that is processable in
|
||||
// this package for STUN message.
|
||||
const MaxPacketSize = 2048
|
||||
|
||||
// MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
|
||||
type MessageClass byte
|
||||
|
||||
// Possible values for message class in STUN Message Type.
|
||||
const (
|
||||
ClassRequest MessageClass = 0x00 // 0b00
|
||||
ClassIndication MessageClass = 0x01 // 0b01
|
||||
ClassSuccessResponse MessageClass = 0x02 // 0b10
|
||||
ClassErrorResponse MessageClass = 0x03 // 0b11
|
||||
)
|
||||
|
||||
func (c MessageClass) String() string {
|
||||
switch c {
|
||||
case ClassRequest:
|
||||
return "request"
|
||||
case ClassIndication:
|
||||
return "indication"
|
||||
case ClassSuccessResponse:
|
||||
return "success response"
|
||||
case ClassErrorResponse:
|
||||
return "error response"
|
||||
default:
|
||||
panic("unknown message class")
|
||||
}
|
||||
}
|
||||
|
||||
// Method is uint16 representation of 12-bit STUN method.
|
||||
type Method uint16
|
||||
|
||||
// Possible methods for STUN Message.
|
||||
const (
|
||||
MethodBinding Method = 0x001
|
||||
MethodAllocate Method = 0x003
|
||||
MethodRefresh Method = 0x004
|
||||
MethodSend Method = 0x006
|
||||
MethodData Method = 0x007
|
||||
MethodCreatePermission Method = 0x008
|
||||
MethodChannelBind Method = 0x009
|
||||
)
|
||||
|
||||
func (m Method) String() string {
|
||||
switch m {
|
||||
case MethodBinding:
|
||||
return "binding"
|
||||
case MethodAllocate:
|
||||
return "allocate"
|
||||
case MethodRefresh:
|
||||
return "refresh"
|
||||
case MethodSend:
|
||||
return "send"
|
||||
case MethodData:
|
||||
return "data"
|
||||
case MethodCreatePermission:
|
||||
return "create permission"
|
||||
case MethodChannelBind:
|
||||
return "channel bind"
|
||||
default:
|
||||
return fmt.Sprintf("0x%s", strconv.FormatUint(uint64(m), 16))
|
||||
}
|
||||
}
|
||||
|
||||
// MessageType is STUN Message Type Field.
|
||||
type MessageType struct {
|
||||
Class MessageClass
|
||||
Method Method
|
||||
}
|
||||
|
||||
const (
|
||||
methodABits = 0xf // 0b0000000000001111
|
||||
methodBBits = 0x70 // 0b0000000001110000
|
||||
methodDBits = 0xf80 // 0b0000111110000000
|
||||
|
||||
methodBShift = 1
|
||||
methodDShift = 2
|
||||
|
||||
firstBit = 0x1
|
||||
secondBit = 0x2
|
||||
|
||||
c0Bit = firstBit
|
||||
c1Bit = secondBit
|
||||
|
||||
classC0Shift = 4
|
||||
classC1Shift = 7
|
||||
)
|
||||
|
||||
// Value returns bit representation of messageType.
|
||||
func (t MessageType) Value() uint16 {
|
||||
// 0 1
|
||||
// 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
|
||||
// |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
|
||||
// +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// Figure 3: Format of STUN Message Type Field
|
||||
|
||||
// warning: Abandon all hope ye who enter here.
|
||||
// splitting M into A(M0-M3), B(M4-M6), D(M7-M11)
|
||||
m := uint16(t.Method)
|
||||
a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
|
||||
b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
|
||||
d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
|
||||
|
||||
// shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit)
|
||||
m = a + (b << methodBShift) + (d << methodDShift)
|
||||
|
||||
// C0 is zero bit of C, C1 is fist bit.
|
||||
// C0 = C * 0b01, C1 = (C * 0b10) >> 1
|
||||
// Ct = C0 << 4 + C1 << 8.
|
||||
// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
|
||||
// We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
|
||||
// (see figure 3).
|
||||
c := uint16(t.Class)
|
||||
c0 := (c & c0Bit) << classC0Shift
|
||||
c1 := (c & c1Bit) << classC1Shift
|
||||
class := c0 + c1
|
||||
|
||||
return m + class
|
||||
}
|
||||
|
||||
// ReadValue decodes uint16 into MessageType.
|
||||
func (t *MessageType) ReadValue(v uint16) {
|
||||
// decoding class
|
||||
// we are taking first bit from v >> 4 and second from v >> 7.
|
||||
c0 := (v >> classC0Shift) & c0Bit
|
||||
c1 := (v >> classC1Shift) & c1Bit
|
||||
class := c0 + c1
|
||||
t.Class = MessageClass(class)
|
||||
|
||||
// decoding method
|
||||
a := v & methodABits // A(M0-M3)
|
||||
b := (v >> methodBShift) & methodBBits // B(M4-M6)
|
||||
d := (v >> methodDShift) & methodDBits // D(M7-M11)
|
||||
m := a + b + d
|
||||
t.Method = Method(m)
|
||||
}
|
||||
|
||||
func (t MessageType) String() string {
|
||||
return fmt.Sprintf("%s %s", t.Method, t.Class)
|
||||
}
|
||||
|
119
stun_test.go
119
stun_test.go
@@ -9,14 +9,13 @@ import (
|
||||
"hash/crc64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"net"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -33,24 +32,11 @@ func (m *Message) reader() *bytes.Reader {
|
||||
return bytes.NewReader(m.Raw)
|
||||
}
|
||||
|
||||
func TestMessageCopy(t *testing.T) {
|
||||
m := New()
|
||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m.TransactionID = NewTransactionID()
|
||||
m.AddRaw(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
m.WriteHeader()
|
||||
mCopy := New()
|
||||
m.CopyTo(mCopy)
|
||||
if !mCopy.Equal(m) {
|
||||
t.Error(mCopy, "!=", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBuffer(t *testing.T) {
|
||||
m := New()
|
||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m.TransactionID = NewTransactionID()
|
||||
m.AddRaw(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
m.WriteHeader()
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
|
||||
@@ -69,7 +55,7 @@ func BenchmarkMessage_Write(b *testing.B) {
|
||||
transactionID := NewTransactionID()
|
||||
m := New()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.AddRaw(AttrErrorCode, attributeValue)
|
||||
m.Add(AttrErrorCode, attributeValue)
|
||||
m.TransactionID = transactionID
|
||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m.WriteHeader()
|
||||
@@ -133,6 +119,25 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_WriteTo(t *testing.T) {
|
||||
m := New()
|
||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m.TransactionID = NewTransactionID()
|
||||
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
m.WriteHeader()
|
||||
buf := new(bytes.Buffer)
|
||||
if _, err := m.WriteTo(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.ReadFrom(buf); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !mDecoded.Equal(m) {
|
||||
t.Error(mDecoded, "!", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_Cookie(t *testing.T) {
|
||||
buf := make([]byte, 20)
|
||||
mDecoded := New()
|
||||
@@ -156,11 +161,11 @@ func TestMessage_BadLength(t *testing.T) {
|
||||
Length: 4,
|
||||
TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
|
||||
}
|
||||
m.AddRaw(0x1, []byte{1, 2})
|
||||
m.Add(0x1, []byte{1, 2})
|
||||
m.WriteHeader()
|
||||
m.Raw[20+3] = 10 // set attr length = 10
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.ReadBytes(m.Raw); err == nil {
|
||||
if _, err := mDecoded.Write(m.Raw); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
}
|
||||
@@ -295,7 +300,7 @@ func BenchmarkMessage_ReadBytes(b *testing.B) {
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
mRec := New()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := mRec.ReadBytes(m.Raw); err != nil {
|
||||
if _, err := mRec.Write(m.Raw); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
mRec.Reset()
|
||||
@@ -439,9 +444,7 @@ func TestMessage_String(t *testing.T) {
|
||||
|
||||
func TestIsMessage(t *testing.T) {
|
||||
m := New()
|
||||
m.Add(&Software{
|
||||
Raw: []byte("test"),
|
||||
})
|
||||
NewSoftware("software").AddTo(m)
|
||||
m.WriteHeader()
|
||||
|
||||
var tt = [...]struct {
|
||||
@@ -468,7 +471,7 @@ func BenchmarkIsMessage(b *testing.B) {
|
||||
m := New()
|
||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m.TransactionID = NewTransactionID()
|
||||
m.AddSoftware("cydev/stun test")
|
||||
NewSoftware("cydev/stun test").AddTo(m)
|
||||
m.WriteHeader()
|
||||
|
||||
b.SetBytes(int64(messageHeaderSize))
|
||||
@@ -502,13 +505,13 @@ func loadData(tb testing.TB, name string) []byte {
|
||||
func TestExampleChrome(t *testing.T) {
|
||||
buf := loadData(t, "ex1_chrome.stun")
|
||||
m := New()
|
||||
_, err := m.ReadBytes(buf)
|
||||
_, err := m.Write(buf)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse ex1_chrome: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNearestLen(t *testing.T) {
|
||||
func TestPadding(t *testing.T) {
|
||||
tt := []struct {
|
||||
in, out int
|
||||
}{
|
||||
@@ -525,7 +528,7 @@ func TestNearestLen(t *testing.T) {
|
||||
{40, 40}, // 10
|
||||
}
|
||||
for i, c := range tt {
|
||||
if got := nearestLength(c.in); got != c.out {
|
||||
if got := nearestPaddedValueLength(c.in); got != c.out {
|
||||
t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)",
|
||||
i, c.in, got, c.out,
|
||||
)
|
||||
@@ -562,7 +565,7 @@ func TestMessageFromBrowsers(t *testing.T) {
|
||||
if b != crc64.Checksum(data, crcTable) {
|
||||
t.Error("crc64 check failed for ", line[1])
|
||||
}
|
||||
if _, err = m.ReadBytes(data); err != nil {
|
||||
if _, err = m.Write(data); err != nil {
|
||||
t.Error("failed to decode ", line[1], " as message: ", err)
|
||||
}
|
||||
m.Reset()
|
||||
@@ -583,20 +586,29 @@ func BenchmarkNewTransactionID(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessage_NewTransactionID(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Message)
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := m.NewTransactionID(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessageFull(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m := new(Message)
|
||||
s := NewSoftware("software")
|
||||
addr := &XORMappedAddress{
|
||||
ip: net.IPv4(213, 1, 223, 5),
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
fingerprint := new(Fingerprint)
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.Add(addr)
|
||||
m.Add(s)
|
||||
addAttr(b, m, addr)
|
||||
addAttr(b, m, s)
|
||||
m.WriteAttributes()
|
||||
m.WriteHeader()
|
||||
fingerprint.AddTo(m)
|
||||
Fingerprint.AddTo(m)
|
||||
m.WriteHeader()
|
||||
m.Reset()
|
||||
}
|
||||
@@ -607,19 +619,26 @@ func BenchmarkMessageFullHardcore(b *testing.B) {
|
||||
m := new(Message)
|
||||
s := NewSoftware("software")
|
||||
addr := &XORMappedAddress{
|
||||
ip: net.IPv4(213, 1, 223, 5),
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
t, v, _ := addr.Encode(nil, m)
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.AddRaw(AttrSoftware, s.Raw)
|
||||
m.AddRaw(t, v)
|
||||
m.WriteAttributes()
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := s.AddTo(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
m.WriteHeader()
|
||||
m.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func addAttr(t testing.TB, m *Message, a AttrEncoder) {
|
||||
if err := m.Add(a); err != nil {
|
||||
type attributeEncoder interface {
|
||||
AddTo(m *Message) error
|
||||
}
|
||||
|
||||
func addAttr(t testing.TB, m *Message, a attributeEncoder) {
|
||||
if err := a.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
@@ -629,14 +648,13 @@ func BenchmarkFingerprint_AddTo(b *testing.B) {
|
||||
m := new(Message)
|
||||
s := NewSoftware("software")
|
||||
addr := &XORMappedAddress{
|
||||
ip: net.IPv4(213, 1, 223, 5),
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
addAttr(b, m, addr)
|
||||
addAttr(b, m, s)
|
||||
fingerprint := new(Fingerprint)
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
fingerprint.AddTo(m)
|
||||
Fingerprint.AddTo(m)
|
||||
m.WriteLength()
|
||||
m.Length -= attributeHeaderSize + fingerprintSize
|
||||
m.Raw = m.Raw[:m.Length+messageHeaderSize]
|
||||
@@ -648,12 +666,10 @@ func TestFingerprint_Check(t *testing.T) {
|
||||
m := new(Message)
|
||||
addAttr(t, m, NewSoftware("software"))
|
||||
m.WriteHeader()
|
||||
fInit := new(Fingerprint)
|
||||
fInit.AddTo(m)
|
||||
Fingerprint.AddTo(m)
|
||||
m.WriteHeader()
|
||||
f := new(Fingerprint)
|
||||
if err := f.Check(m); err != nil {
|
||||
t.Errorf("decoded: %d, encoded: %d, err: %s", f.Value, fInit.Value, err)
|
||||
if err := Fingerprint.Check(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -662,17 +678,16 @@ func BenchmarkFingerprint_Check(b *testing.B) {
|
||||
m := new(Message)
|
||||
s := NewSoftware("software")
|
||||
addr := &XORMappedAddress{
|
||||
ip: net.IPv4(213, 1, 223, 5),
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
addAttr(b, m, addr)
|
||||
addAttr(b, m, s)
|
||||
fingerprint := new(Fingerprint)
|
||||
m.WriteHeader()
|
||||
fingerprint.AddTo(m)
|
||||
Fingerprint.AddTo(m)
|
||||
m.WriteHeader()
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := fingerprint.Check(m); err != nil {
|
||||
if err := Fingerprint.Check(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@@ -5,9 +5,11 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"runtime"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
_ "net/http/pprof"
|
||||
|
||||
"github.com/ernado/stun"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -16,6 +18,8 @@ import (
|
||||
var (
|
||||
network = flag.String("net", "udp", "network to listen")
|
||||
address = flag.String("addr", "0.0.0.0:3478", "address to listen")
|
||||
workers = flag.Int("workers", 1, "workers to start")
|
||||
profile = flag.Bool("profile", false, "profile")
|
||||
)
|
||||
|
||||
// Server is RFC 5389 basic server implementation.
|
||||
@@ -80,7 +84,7 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
|
||||
if !stun.IsMessage(b) {
|
||||
return errNotSTUNMessage
|
||||
}
|
||||
if _, err := req.ReadBytes(b); err != nil {
|
||||
if _, err := req.Write(b); err != nil {
|
||||
return errors.Wrap(err, "failed to read message")
|
||||
}
|
||||
res.TransactionID = req.TransactionID
|
||||
@@ -99,8 +103,12 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error {
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown addr: %v", addr))
|
||||
}
|
||||
res.AddXORMappedAddress(ip, port)
|
||||
res.AddRaw(stun.AttrSoftware, software.Raw)
|
||||
xma := &stun.XORMappedAddress{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
}
|
||||
xma.AddTo(res)
|
||||
software.AddTo(res)
|
||||
res.WriteHeader()
|
||||
return nil
|
||||
}
|
||||
@@ -116,8 +124,8 @@ func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error {
|
||||
return nil
|
||||
}
|
||||
// s.logger().Printf("read %d bytes from %s", n, addr)
|
||||
if _, err = req.ReadBytes(buf[:n]); err != nil {
|
||||
s.logger().Printf("ReadBytes: %v", err)
|
||||
if _, err = req.Write(buf[:n]); err != nil {
|
||||
s.logger().Printf("Write: %v", err)
|
||||
return err
|
||||
}
|
||||
if err = basicProcess(addr, buf[:n], req, res); err != nil {
|
||||
@@ -156,8 +164,10 @@ func ListenUDPAndServe(serverNet, laddr string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := &Server{}
|
||||
return s.Serve(c)
|
||||
for i := 1; i < *workers; i++ {
|
||||
go new(Server).Serve(c)
|
||||
}
|
||||
return new(Server).Serve(c)
|
||||
}
|
||||
|
||||
func normalize(address string) string {
|
||||
@@ -172,7 +182,11 @@ func normalize(address string) string {
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
runtime.GOMAXPROCS(1)
|
||||
if *profile {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
}
|
||||
switch *network {
|
||||
case "udp":
|
||||
normalized := normalize(*address)
|
||||
|
@@ -17,7 +17,7 @@ func BenchmarkBasicProcess(b *testing.B) {
|
||||
b.Fatal(err)
|
||||
}
|
||||
m.TransactionID = stun.NewTransactionID()
|
||||
m.Add(stun.NewSoftware("some software"))
|
||||
stun.NewSoftware("some software").AddTo(m)
|
||||
m.WriteHeader()
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
Reference in New Issue
Block a user