diff --git a/LICENSE b/LICENSE index e3d1a3a..e405cb6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2016 Aleksandr Razumov, Cydev. All Rigths Reserved. +Copyright (c) 2016-2017 Aleksandr Razumov, Cydev. All Rigths Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/attribute_errorcode.go b/attribute_errorcode.go index d3f0a29..047b497 100644 --- a/attribute_errorcode.go +++ b/attribute_errorcode.go @@ -2,39 +2,88 @@ package stun // ErrorCodeAttribute represents ERROR-CODE attribute. type ErrorCodeAttribute struct { - Code int + Code ErrorCode Reason []byte } +// constants for ERROR-CODE encoding. +const ( + errorCodeReasonStart = 4 + errorCodeClassByte = 2 + errorCodeNumberByte = 3 + errorCodeReasonMaxB = 763 + errorCodeModulo = 100 +) + +// AddTo adds ERROR-CODE to m. +func (c *ErrorCodeAttribute) AddTo(m *Message) error { + value := make([]byte, + errorCodeReasonStart, errorCodeReasonMaxB, + ) + number := byte(c.Code % errorCodeModulo) // error code modulo 100 + class := byte(c.Code / errorCodeModulo) // hundred digit + value[errorCodeClassByte] = class + value[errorCodeNumberByte] = number + value = append(value, c.Reason...) + m.Add(AttrErrorCode, value) + return nil +} + +// GetFrom decodes ERROR-CODE from m. +func (c *ErrorCodeAttribute) GetFrom(m *Message) error { + v, err := m.Get(AttrErrorCode) + if err != nil { + return err + } + var ( + class = uint16(v[errorCodeClassByte]) + number = uint16(v[errorCodeNumberByte]) + code = int(class*errorCodeModulo + number) + reason = v[errorCodeReasonStart:] + ) + c.Code = ErrorCode(code) + c.Reason = reason + return nil +} + // ErrorCode is code for ERROR-CODE attribute. type ErrorCode int +// ErrNoDefaultReason means that default reason for provided error code +// is not defined in RFC. +const ErrNoDefaultReason Error = "No default reason for ErrorCode" + +// AddTo adds ERROR-CODE with default reason to m. If there +// is no default reason, returns ErrNoDefaultReason. +func (c ErrorCode) AddTo(m *Message) error { + reason := errorReasons[c] + if reason == nil { + return ErrNoDefaultReason + } + a := &ErrorCodeAttribute{ + Code: c, + Reason: reason, + } + return a.AddTo(m) +} + // Possible error codes. const ( - CodeTryAlternate = 300 - CodeBadRequest = 400 - CodeUnauthorised = 401 - CodeUnknownAttribute = 420 - CodeStaleNonce = 428 - CodeRoleConflict = 478 - CodeServerError = 500 + CodeTryAlternate ErrorCode = 300 + CodeBadRequest ErrorCode = 400 + CodeUnauthorised ErrorCode = 401 + CodeUnknownAttribute ErrorCode = 420 + CodeStaleNonce ErrorCode = 428 + CodeRoleConflict ErrorCode = 478 + CodeServerError ErrorCode = 500 ) -var errorReasons = map[int]string{ - CodeTryAlternate: "Try Alternate", - CodeBadRequest: "Bad Request", - CodeUnauthorised: "Unauthorised", - CodeUnknownAttribute: "Unknown Attribute", - CodeStaleNonce: "Stale Nonce", - CodeServerError: "Server Error", - CodeRoleConflict: "Role Conflict", -} - -// Reason returns recommended reason string. -func (c ErrorCode) Reason() string { - reason, ok := errorReasons[int(c)] - if !ok { - return "Unknown Error" - } - return reason +var errorReasons = map[ErrorCode][]byte{ + CodeTryAlternate: []byte("Try Alternate"), + CodeBadRequest: []byte("Bad Request"), + CodeUnauthorised: []byte("Unauthorised"), + CodeUnknownAttribute: []byte("Unknown Attribute"), + CodeStaleNonce: []byte("Stale Nonce"), + CodeServerError: []byte("Server Error"), + CodeRoleConflict: []byte("Role Conflict"), } diff --git a/attribute_errorcode_test.go b/attribute_errorcode_test.go deleted file mode 100644 index f42697d..0000000 --- a/attribute_errorcode_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package stun - -import "testing" - -func TestErrorCode_Reason(t *testing.T) { - codes := [...]ErrorCode{ - CodeTryAlternate, - CodeBadRequest, - CodeUnauthorised, - CodeUnknownAttribute, - CodeStaleNonce, - CodeRoleConflict, - CodeServerError, - } - for _, code := range codes { - if code.Reason() == "Unknown Error" { - t.Error(code, "should not be unknown") - } - if len(code.Reason()) == 0 { - t.Error(code, "should not be blank") - } - } - if ErrorCode(999).Reason() != "Unknown Error" { - t.Error("999 error should be Unknown") - } -} diff --git a/attribute_fingerprint.go b/attribute_fingerprint.go new file mode 100644 index 0000000..73e3ec5 --- /dev/null +++ b/attribute_fingerprint.go @@ -0,0 +1,68 @@ +package stun + +import ( + "fmt" + "hash/crc32" +) + +// FingerprintAttr represent FINGERPRINT attribute. +type FingerprintAttr struct{} + +// CRCMismatch represents CRC check error. +type CRCMismatch struct { + Expected uint32 + Actual uint32 +} + +func (m CRCMismatch) Error() string { + return fmt.Sprintf("CRC mismatch: %x (expected) != %x (actual)", + m.Expected, + m.Actual, + ) +} + +// Fingerprint is shorthand for FingerprintAttr. +var Fingerprint = &FingerprintAttr{} + +const ( + fingerprintXORValue uint32 = 0x5354554e + fingerprintSize = 4 // 32 bit +) + +// FingerprintValue returns CRC32 of m XOR-ed by 0x5354554e. +func FingerprintValue(b []byte) uint32 { + return crc32.ChecksumIEEE(b) ^ fingerprintXORValue // XOR +} + +// AddTo adds fingerprint to message. +func (FingerprintAttr) AddTo(m *Message) error { + l := m.Length + // length in header should include size of fingerprint attribute + m.Length += fingerprintSize + attributeHeaderSize // increasing length + m.WriteLength() // writing Length to Raw + b := make([]byte, fingerprintSize) + v := FingerprintValue(m.Raw) + bin.PutUint32(b, v) + m.Length = l + m.Add(AttrFingerprint, b) + return nil +} + +// Check reads fingerprint value from m and checks it, returning error if any. +// Can return *DecodeErr, ErrAttributeNotFound and *CRCMismatch. +func (FingerprintAttr) Check(m *Message) error { + v, err := m.Get(AttrFingerprint) + if err != nil { + return err + } + if len(v) != fingerprintSize { + return newDecodeErr("message", "fingerprint", "bad length") + } + val := bin.Uint32(v) + attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize) + expected := FingerprintValue(m.Raw[:attrStart]) + if expected != val { + return &CRCMismatch{Expected: expected, Actual: val} + } + return nil +} diff --git a/attribute_software.go b/attribute_software.go new file mode 100644 index 0000000..d25c2f0 --- /dev/null +++ b/attribute_software.go @@ -0,0 +1,31 @@ +package stun + +// Software is SOFTWARE attribute. +type Software struct { + Raw []byte +} + +func (s *Software) String() string { + return string(s.Raw) +} + +// NewSoftware returns *Software from string. +func NewSoftware(software string) *Software { + return &Software{Raw: []byte(software)} +} + +// AddTo adds Software attribute to m. +func (s *Software) AddTo(m *Message) error { + m.Add(AttrSoftware, m.Raw) + return nil +} + +// GetFrom decodes Software from m. +func (s *Software) GetFrom(m *Message) error { + v, err := m.Get(AttrSoftware) + if err != nil { + return err + } + s.Raw = v + return nil +} diff --git a/attribute_xoraddr.go b/attribute_xoraddr.go new file mode 100644 index 0000000..3bca3dc --- /dev/null +++ b/attribute_xoraddr.go @@ -0,0 +1,118 @@ +package stun + +import ( + "fmt" + "net" +) + +const ( + familyIPv4 byte = 0x01 + familyIPv6 byte = 0x02 +) + +// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute. +type XORMappedAddress struct { + IP net.IP + Port int +} + +func (a XORMappedAddress) String() string { + return fmt.Sprintf("%s:%d", a.IP, a.Port) +} + +// Is p all zeros? +func isZeros(p net.IP) bool { + for i := 0; i < len(p); i++ { + if p[i] != 0 { + return false + } + } + return true +} + +// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}. +const ErrBadIPLength Error = "invalid length if IP value" + +// AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength +// if len(a.IP) is invalid. +func (a *XORMappedAddress) AddTo(m *Message) error { + var ( + family = familyIPv4 + ip = a.IP + ) + if len(a.IP) == net.IPv6len { + // Optimized for performance. See net.IP.To4 method. + if isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { + ip = ip[12:16] // like in ip.To4() + } else { + family = familyIPv6 + } + } else if len(ip) != net.IPv4len { + return ErrBadIPLength + } + value := make([]byte, 32+128) + value[0] = 0 // first 8 bits are zeroes + xorValue := make([]byte, net.IPv6len) + copy(xorValue[4:], m.TransactionID[:]) + bin.PutUint32(xorValue[0:4], magicCookie) + bin.PutUint16(value[0:2], uint16(family)) + bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) + xorBytes(value[4:4+len(ip)], ip, xorValue) + m.Add(AttrXORMappedAddress, value[:4+len(ip)]) + return nil +} + +// GetFrom decodes XOR-MAPPED-ADDRESS attribute in message and returns +// error if any. While decoding, a.IP is reused if possible and can be +// rendered to invalid state (e.g. if a.IP was set to IPv6 and then +// IPv4 value were decoded into it), be careful. +// +// Example: +// +// expectedIP := net.ParseIP("213.141.156.236") +// expectedIP.String() // 213.141.156.236, 16 bytes, first 12 of them are zeroes +// expectedPort := 21254 +// addr := &XORMappedAddress{ +// IP: expectedIP, +// Port: expectedPort, +// } +// // addr were added to message that is decoded as newMessage +// // ... +// +// addr.GetFrom(newMessage) +// addr.IP.String() // 213.141.156.236, net.IPv4Len +// expectedIP.String() // d58d:9cec::ffff:d58d:9cec, 16 bytes, first 4 are IPv4 +// // now we have len(expectedIP) = 16 and len(addr.IP) = 4. +func (a *XORMappedAddress) GetFrom(m *Message) error { + v, err := m.Get(AttrXORMappedAddress) + if err != nil { + return err + } + family := byte(bin.Uint16(v[0:2])) + if family != familyIPv6 && family != familyIPv4 { + return newDecodeErr("xor-mapped address", "family", + fmt.Sprintf("bad value %d", family), + ) + } + ipLen := net.IPv4len + if family == familyIPv6 { + ipLen = net.IPv6len + } + // Ensuring len(a.IP) == ipLen and reusing a.IP. + if len(a.IP) < ipLen { + a.IP = a.IP[:cap(a.IP)] + for len(a.IP) < ipLen { + a.IP = append(a.IP, 0) + } + } + a.IP = a.IP[:ipLen] + for i := range a.IP { + a.IP[i] = 0 + } + a.Port = int(bin.Uint16(v[2:4])) ^ (magicCookie >> 16) + xorValue := make([]byte, 4+transactionIDSize) + bin.PutUint32(xorValue[0:4], magicCookie) + copy(xorValue[4:], m.TransactionID[:]) + xorBytes(a.IP, v[4:], xorValue) + return nil +} diff --git a/attributes.go b/attributes.go index 3f2a28d..6ea1d33 100644 --- a/attributes.go +++ b/attributes.go @@ -1,23 +1,10 @@ package stun import ( - "encoding/binary" "fmt" - "hash/crc32" - "net" "strconv" ) -// AttrWriter wraps AddRaw method. -type AttrWriter interface { - AddRaw(t AttrType, v []byte) -} - -// AttrEncoder wraps Encode method. -type AttrEncoder interface { - Encode(b []byte, m *Message) (AttrType, []byte, error) -} - // Attributes is list of message attributes. type Attributes []RawAttribute @@ -77,7 +64,7 @@ const ( AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN ) -// Attributes from An Origin RawAttribute for the STUN Protocol. +// Attributes from An Origin Attribute for the STUN Protocol. const ( AttrOrigin AttrType = 0x802F ) @@ -118,34 +105,30 @@ var attrNames = map[AttrType]string{ func (t AttrType) String() string { s, ok := attrNames[t] if !ok { - // just return hex representation of unknown attribute type + // Just return hex representation of unknown attribute type. return "0x" + strconv.FormatUint(uint64(t), 16) } return s } // RawAttribute is a Type-Length-Value (TLV) object that -// can be added to a STUN message. Attributes are divided into two +// can be added to a STUN message. Attributes are divided into two // types: comprehension-required and comprehension-optional. STUN // agents can safely ignore comprehension-optional attributes they // don't understand, but cannot successfully process a message if it // contains comprehension-required attributes that are not // understood. +// +// TODO(ar): Decide to use pointer or non-pointer RawAttribute. type RawAttribute struct { Type AttrType Length uint16 // ignored while encoding Value []byte } -// Encode implements AttrEncoder. -func (a *RawAttribute) Encode(m *Message) ([]byte, error) { - return m.Raw, nil -} - -// Decode implements AttrDecoder. -func (a *RawAttribute) Decode(v []byte, m *Message) error { - a.Value = v - a.Length = uint16(len(v)) +// AddTo adds RawAttribute to m. +func (a *RawAttribute) AddTo(m *Message) error { + m.Add(a.Type, m.Raw) return nil } @@ -172,319 +155,17 @@ func (a RawAttribute) String() string { return fmt.Sprintf("%s: %x", a.Type, a.Value) } -// getAttrValue returns byte slice that represents attribute value, -// if there is no value attribute with shuck type, +// ErrAttributeNotFound means that attribute with provided attribute +// type does not exist in message. +const ErrAttributeNotFound Error = "Attribute not found" + +// Get returns byte slice that represents attribute value, +// if there is no attribute with such type, // ErrAttributeNotFound is returned. -func (m *Message) getAttrValue(t AttrType) ([]byte, error) { +func (m *Message) Get(t AttrType) ([]byte, error) { v, ok := m.Attributes.Get(t) if !ok { return nil, ErrAttributeNotFound } return v.Value, nil } - -// AddSoftware adds SOFTWARE attribute with value from string. -// Deprecated: use AddRaw. -func (m *Message) AddSoftware(software string) { - m.AddRaw(AttrSoftware, []byte(software)) -} - -// Set sets the value of attribute if it presents. -func (m *Message) Set(a AttrEncoder) error { - var ( - v []byte - err error - t AttrType - ) - t, v, err = a.Encode(v, m) - if err != nil { - return err - } - buf, err := m.getAttrValue(t) - if err != nil { - return err - } - if len(v) != len(buf) { - return ErrBadSetLength - } - copy(buf, v) - return nil -} - -// GetSoftwareBytes returns SOFTWARE attribute value in byte slice. -// If not found, returns nil. -func (m *Message) GetSoftwareBytes() []byte { - v, ok := m.Attributes.Get(AttrSoftware) - if !ok { - return nil - } - return v.Value -} - -// GetSoftware returns SOFTWARE attribute value in string. -// If not found, returns blank string. -// Deprecated. -func (m *Message) GetSoftware() string { return string(m.GetSoftwareBytes()) } - -// Address family values. -const ( - FamilyIPv4 byte = 0x01 - FamilyIPv6 byte = 0x02 -) - -// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute. -type XORMappedAddress struct { - ip net.IP - port int -} - -// Encode implements AttrEncoder. -func (a *XORMappedAddress) Encode(buf []byte, m *Message) (AttrType, []byte, error) { - // X-Port is computed by taking the mapped port in host byte order, - // XOR’ing it with the most significant 16 bits of the magic cookie, and - // then the converting the result to network byte order. - family := FamilyIPv6 - ip := a.ip - port := a.port - if ipV4 := ip.To4(); ipV4 != nil { - ip = ipV4 - family = FamilyIPv4 - } - value := make([]byte, 32+128) - value[0] = 0 // first 8 bits are zeroes - xorValue := make([]byte, net.IPv6len) - copy(xorValue[4:], m.TransactionID[:]) - binary.BigEndian.PutUint32(xorValue[0:4], magicCookie) - port ^= magicCookie >> 16 - binary.BigEndian.PutUint16(value[0:2], uint16(family)) - binary.BigEndian.PutUint16(value[2:4], uint16(port)) - xorBytes(value[4:4+len(ip)], ip, xorValue) - buf = append(buf, value...) - return AttrXORMappedAddress, buf, nil -} - -// Decode implements AttrDecoder. -// TODO(ar): fix signature. -func (a *XORMappedAddress) Decode(v []byte, m *Message) error { - // X-Port is computed by taking the mapped port in host byte order, - // XOR’ing it with the most significant 16 bits of the magic cookie, and - // then the converting the result to network byte order. - v, err := m.getAttrValue(AttrXORMappedAddress) - if err != nil { - return err - } - family := byte(binary.BigEndian.Uint16(v[0:2])) - if family != FamilyIPv6 && family != FamilyIPv4 { - return newDecodeErr("xor-mapped address", "family", - fmt.Sprintf("bad value %d", family), - ) - } - ipLen := net.IPv4len - if family == FamilyIPv6 { - ipLen = net.IPv6len - } - ip := net.IP(m.allocBuffer(ipLen)) - a.port = int(binary.BigEndian.Uint16(v[2:4])) ^ (magicCookie >> 16) - xorValue := make([]byte, 128) - binary.BigEndian.PutUint32(xorValue[0:4], magicCookie) - copy(xorValue[4:], m.TransactionID[:]) - xorBytes(ip, v[4:], xorValue) - a.ip = ip - return nil -} - -// AddXORMappedAddress adds XOR MAPPED ADDRESS attribute to message. -// Deprecated: use AddRaw. -func (m *Message) AddXORMappedAddress(ip net.IP, port int) { - // X-Port is computed by taking the mapped port in host byte order, - // XOR’ing it with the most significant 16 bits of the magic cookie, and - // then the converting the result to network byte order. - family := FamilyIPv6 - if ipV4 := ip.To4(); ipV4 != nil { - ip = ipV4 - family = FamilyIPv4 - } - value := make([]byte, 32+128) - value[0] = 0 // first 8 bits are zeroes - xorValue := make([]byte, net.IPv6len) - copy(xorValue[4:], m.TransactionID[:]) - binary.BigEndian.PutUint32(xorValue[0:4], magicCookie) - port ^= magicCookie >> 16 - binary.BigEndian.PutUint16(value[0:2], uint16(family)) - binary.BigEndian.PutUint16(value[2:4], uint16(port)) - xorBytes(value[4:4+len(ip)], ip, xorValue) - m.AddRaw(AttrXORMappedAddress, value[:4+len(ip)]) -} - -func (m *Message) allocBuffer(size int) []byte { - capacity := len(m.Raw) + size - m.grow(capacity) - m.Raw = m.Raw[:capacity] - return m.Raw[len(m.Raw)-size:] -} - -// GetXORMappedAddress returns ip, port from attribute and error if any. -// Value for ip is valid until Message is released or underlying buffer is -// corrupted. Returns *DecodeError or ErrAttributeNotFound. -// Deprecated: use GetRaw. -func (m *Message) GetXORMappedAddress() (net.IP, int, error) { - // X-Port is computed by taking the mapped port in host byte order, - // XOR’ing it with the most significant 16 bits of the magic cookie, and - // then the converting the result to network byte order. - v, err := m.getAttrValue(AttrXORMappedAddress) - if err != nil { - return nil, 0, err - } - family := byte(binary.BigEndian.Uint16(v[0:2])) - if family != FamilyIPv6 && family != FamilyIPv4 { - return nil, 0, newDecodeErr("xor-mapped address", "family", - fmt.Sprintf("bad value %d", family), - ) - } - ipLen := net.IPv4len - if family == FamilyIPv6 { - ipLen = net.IPv6len - } - ip := net.IP(m.allocBuffer(ipLen)) - port := int(binary.BigEndian.Uint16(v[2:4])) ^ (magicCookie >> 16) - xorValue := make([]byte, 128) - binary.BigEndian.PutUint32(xorValue[0:4], magicCookie) - copy(xorValue[4:], m.TransactionID[:]) - xorBytes(ip, v[4:], xorValue) - return ip, port, nil -} - -// constants for ERROR-CODE encoding. -const ( - errorCodeReasonStart = 4 - errorCodeClassByte = 2 - errorCodeNumberByte = 3 - errorCodeReasonMaxB = 763 - errorCodeModulo = 100 -) - -// AddErrorCode adds ERROR-CODE attribute to message. -// -// The reason phrase MUST be a UTF-8 [RFC 3629] encoded -// sequence of less than 128 characters (which can be as long as 763 -// bytes). -// Deprecated: use AddRaw. -func (m *Message) AddErrorCode(code int, reason string) { - value := make([]byte, - errorCodeReasonStart, errorCodeReasonMaxB+errorCodeReasonStart, - ) - number := byte(code % errorCodeModulo) // error code modulo 100 - class := byte(code / errorCodeModulo) // hundred digit - value[errorCodeClassByte] = class - value[errorCodeNumberByte] = number - value = append(value, reason...) - m.AddRaw(AttrErrorCode, value) -} - -// AddErrorCodeDefault is wrapper for AddErrorCode that uses recommended -// reason string from RFC. If error code is unknown, reason will be "Unknown -// Error". -// Deprecated: use AddRaw. -func (m *Message) AddErrorCodeDefault(code int) { - m.AddErrorCode(code, ErrorCode(code).Reason()) -} - -// GetErrorCode returns ERROR-CODE code, reason and decode error if any. -// Deprecated: use GetRaw. -func (m *Message) GetErrorCode() (int, []byte, error) { - v, err := m.getAttrValue(AttrErrorCode) - if err != nil { - return 0, nil, err - } - var ( - class = uint16(v[errorCodeClassByte]) - number = uint16(v[errorCodeNumberByte]) - code = int(class*errorCodeModulo + number) - reason = v[errorCodeReasonStart:] - ) - return code, reason, nil -} - -const ( - // ErrAttributeNotFound means that there is no such attribute. - ErrAttributeNotFound Error = "Attribute not found" - - // ErrBadSetLength means that previous attribute value length differs from - // new value. - ErrBadSetLength Error = "Previous attribute length is different" -) - -// Software is SOFTWARE attribute. -type Software struct { - Raw []byte -} - -func (s Software) String() string { - return string(s.Raw) -} - -// NewSoftware returns *Software from string. -func NewSoftware(software string) *Software { - return &Software{Raw: []byte(software)} -} - -// Encode implements AttrEncoder. -func (s *Software) Encode(b []byte, m *Message) (AttrType, []byte, error) { - return AttrSoftware, append(b, s.Raw...), nil -} - -// Decode implements AttrDecoder. -func (s *Software) Decode(v []byte, m *Message) error { - s.Raw = v - return nil -} - -const ( - fingerprintXORValue uint32 = 0x5354554e -) - -// Fingerprint represents FINGERPRINT attribute. -type Fingerprint struct { - Value uint32 // CRC-32 of message XOR-ed with 0x5354554e -} - -const ( - fingerprintSize = 4 // 32 bit -) - -// AddTo adds fingerprint to message. -func (f *Fingerprint) AddTo(m *Message) error { - l := m.Length - // length in header should include size of fingerprint attribute - m.Length += fingerprintSize + attributeHeaderSize // increasing length - m.WriteLength() // writing Length to Raw - b := make([]byte, fingerprintSize) - f.Value = crc32.ChecksumIEEE(m.Raw) ^ fingerprintXORValue // XOR - bin.PutUint32(b, f.Value) - m.Length = l - m.AddRaw(AttrFingerprint, b) - return nil -} - -// Check reads fingerprint value from m and checks it, returning error if any. -// Can return *DecodeErr, ErrAttributeNotFound, ErrCRCMissmatch. -func (f *Fingerprint) Check(m *Message) error { - v, err := m.getAttrValue(AttrFingerprint) - if err != nil { - return err - } - if len(v) != fingerprintSize { - return newDecodeErr("message", "fingerprint", "bad length") - } - f.Value = bin.Uint32(v) - attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize) - expected := crc32.ChecksumIEEE(m.Raw[:attrStart]) ^ fingerprintXORValue - if expected != f.Value { - return ErrCRCMissmatch - } - return nil -} - -// ErrCRCMissmatch means that calculated fingerprint attribute differs from -// expected one. -const ErrCRCMissmatch Error = "CRC32 missmatch: bad fingerprint value" diff --git a/attributes_test.go b/attributes_test.go index d2cbf16..acae3dd 100644 --- a/attributes_test.go +++ b/attributes_test.go @@ -10,21 +10,24 @@ import ( "testing" ) -func TestMessage_AddSoftware(t *testing.T) { +func TestSoftware_GetFrom(t *testing.T) { m := New() v := "Client v0.0.1" - m.AddRaw(AttrSoftware, []byte(v)) + m.Add(AttrSoftware, []byte(v)) m.WriteHeader() m2 := &Message{ Raw: make([]byte, 0, 256), } + software := new(Software) if _, err := m2.ReadFrom(m.reader()); err != nil { t.Error(err) } - vRead := m.GetSoftware() - if vRead != v { - t.Errorf("Expected %s, got %s.", v, vRead) + if err := software.GetFrom(m); err != nil { + t.Fatal(err) + } + if software.String() != v { + t.Errorf("Expected %q, got %q.", v, software) } sAttr, ok := m.Attributes.Get(AttrSoftware) @@ -37,29 +40,18 @@ func TestMessage_AddSoftware(t *testing.T) { } } -func TestMessage_GetSoftware(t *testing.T) { - m := New() - v := m.GetSoftware() - if v != "" { - t.Errorf("%s should be blank.", v) - } - vByte := m.GetSoftwareBytes() - if vByte != nil { - t.Errorf("%s should be nil.", vByte) - } -} - -func BenchmarkMessage_AddXORMappedAddress(b *testing.B) { +func BenchmarkXORMappedAddress_AddTo(b *testing.B) { m := New() b.ReportAllocs() ip := net.ParseIP("192.168.1.32") for i := 0; i < b.N; i++ { - m.AddXORMappedAddress(ip, 3654) + addr := &XORMappedAddress{IP: ip, Port: 3654} + addr.AddTo(m) m.Reset() } } -func BenchmarkMessage_GetXORMappedAddress(b *testing.B) { +func BenchmarkXORMappedAddress_GetFrom(b *testing.B) { m := New() transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") if err != nil { @@ -70,10 +62,13 @@ func BenchmarkMessage_GetXORMappedAddress(b *testing.B) { if err != nil { b.Error(err) } + m.Add(AttrXORMappedAddress, addrValue) + addr := new(XORMappedAddress) + b.ReportAllocs() for i := 0; i < b.N; i++ { - m.AddRaw(AttrXORMappedAddress, addrValue) - m.GetXORMappedAddress() - m.Reset() + if err := addr.GetFrom(m); err != nil { + b.Fatal(err) + } } } @@ -88,16 +83,16 @@ func TestMessage_GetXORMappedAddress(t *testing.T) { if err != nil { t.Error(err) } - m.AddRaw(AttrXORMappedAddress, addrValue) - ip, port, err := m.GetXORMappedAddress() - if err != nil { + m.Add(AttrXORMappedAddress, addrValue) + addr := new(XORMappedAddress) + if err = addr.GetFrom(m); err != nil { t.Error(err) } - if !ip.Equal(net.ParseIP("213.141.156.236")) { - t.Error("bad ip", ip, "!=", "213.141.156.236") + if !addr.IP.Equal(net.ParseIP("213.141.156.236")) { + t.Error("bad IP", addr.IP, "!=", "213.141.156.236") } - if port != 48583 { - t.Error("bad port", port, "!=", 48583) + if addr.Port != 48583 { + t.Error("bad Port", addr.Port, "!=", 48583) } } @@ -110,13 +105,15 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) { copy(m.TransactionID[:], transactionID) expectedIP := net.ParseIP("213.141.156.236") expectedPort := 21254 + addr := new(XORMappedAddress) - _, _, err = m.GetXORMappedAddress() - if err == nil { + if err = addr.GetFrom(m); err == nil { t.Fatal(err, "should be nil") } - m.AddXORMappedAddress(expectedIP, expectedPort) + addr.IP = expectedIP + addr.Port = expectedPort + addr.AddTo(m) m.WriteHeader() mRes := New() @@ -124,8 +121,7 @@ func TestMessage_GetXORMappedAddressBad(t *testing.T) { if _, err = mRes.ReadFrom(bytes.NewReader(m.Raw)); err != nil { t.Fatal(err) } - _, _, err = m.GetXORMappedAddress() - if err == nil { + if err = addr.GetFrom(m); err == nil { t.Fatal(err, "should not be nil") } } @@ -139,22 +135,26 @@ func TestMessage_AddXORMappedAddress(t *testing.T) { copy(m.TransactionID[:], transactionID) expectedIP := net.ParseIP("213.141.156.236") expectedPort := 21254 - m.AddXORMappedAddress(expectedIP, expectedPort) + addr := &XORMappedAddress{ + IP: net.ParseIP("213.141.156.236"), + Port: expectedPort, + } + if err = addr.AddTo(m); err != nil { + t.Fatal(err) + } m.WriteHeader() - mRes := New() - if _, err = mRes.ReadFrom(m.reader()); err != nil { + if _, err = mRes.Write(m.Raw); err != nil { t.Fatal(err) } - ip, port, err := m.GetXORMappedAddress() - if err != nil { + if err = addr.GetFrom(mRes); err != nil { t.Fatal(err) } - if !ip.Equal(expectedIP) { - t.Error("bad ip", ip, "!=", expectedIP) + if !addr.IP.Equal(expectedIP) { + t.Errorf("%s (got) != %s (expected)", addr.IP, expectedIP) } - if port != expectedPort { - t.Error("bad port", port, "!=", expectedPort) + if addr.Port != expectedPort { + t.Error("bad Port", addr.Port, "!=", expectedPort) } } @@ -167,30 +167,47 @@ func TestMessage_AddXORMappedAddressV6(t *testing.T) { copy(m.TransactionID[:], transactionID) expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009") expectedPort := 21254 - m.AddXORMappedAddress(expectedIP, expectedPort) + addr := &XORMappedAddress{ + IP: net.ParseIP("fe80::dc2b:44ff:fe20:6009"), + Port: 21254, + } + addr.AddTo(m) m.WriteHeader() mRes := New() if _, err = mRes.ReadFrom(m.reader()); err != nil { t.Fatal(err) } - ip, port, err := m.GetXORMappedAddress() - if err != nil { + gotAddr := new(XORMappedAddress) + if err = gotAddr.GetFrom(m); err != nil { t.Fatal(err) } - if !ip.Equal(expectedIP) { - t.Error("bad ip", ip, "!=", expectedIP) + if !gotAddr.IP.Equal(expectedIP) { + t.Error("bad IP", gotAddr.IP, "!=", expectedIP) } - if port != expectedPort { - t.Error("bad port", port, "!=", expectedPort) + if gotAddr.Port != expectedPort { + t.Error("bad Port", gotAddr.Port, "!=", expectedPort) } } -func BenchmarkMessage_AddErrorCode(b *testing.B) { +func BenchmarkErrorCode_AddTo(b *testing.B) { m := New() b.ReportAllocs() for i := 0; i < b.N; i++ { - m.AddErrorCode(404, "Not found") + CodeStaleNonce.AddTo(m) + m.Reset() + } +} + +func BenchmarkErrorCodeAttribute_AddTo(b *testing.B) { + m := New() + b.ReportAllocs() + a := &ErrorCodeAttribute{ + Code: 404, + Reason: []byte("not found!"), + } + for i := 0; i < b.N; i++ { + a.AddTo(m) m.Reset() } } @@ -202,51 +219,27 @@ func TestMessage_AddErrorCode(t *testing.T) { t.Error(err) } copy(m.TransactionID[:], transactionID) - expectedCode := 404 - expectedReason := "Not found" - m.AddErrorCode(expectedCode, expectedReason) + expectedCode := ErrorCode(428) + expectedReason := "Stale Nonce" + CodeStaleNonce.AddTo(m) m.WriteHeader() mRes := New() if _, err = mRes.ReadFrom(m.reader()); err != nil { t.Fatal(err) } - code, reason, err := mRes.GetErrorCode() + errCodeAttr := new(ErrorCodeAttribute) + if err = errCodeAttr.GetFrom(mRes); err != nil { + t.Error(err) + } + code := errCodeAttr.Code if err != nil { t.Error(err) } if code != expectedCode { t.Error("bad code", code) } - if string(reason) != expectedReason { - t.Error("bad reason", string(reason)) - } -} - -func TestMessage_AddErrorCodeDefault(t *testing.T) { - m := New() - transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") - if err != nil { - t.Error(err) - } - copy(m.TransactionID[:], transactionID) - expectedCode := 500 - expectedReason := "Server Error" - m.AddErrorCodeDefault(expectedCode) - m.WriteHeader() - - mRes := New() - if _, err = mRes.ReadFrom(m.reader()); err != nil { - t.Fatal(err) - } - code, reason, err := mRes.GetErrorCode() - if err != nil { - t.Error(err) - } - if code != expectedCode { - t.Error("bad code", code) - } - if string(reason) != expectedReason { - t.Error("bad reason", string(reason)) + if string(errCodeAttr.Reason) != expectedReason { + t.Error("bad reason", string(errCodeAttr.Reason)) } } diff --git a/errors.go b/errors.go index c37b65c..d47ec5c 100644 --- a/errors.go +++ b/errors.go @@ -15,6 +15,12 @@ type DecodeErr struct { Message string } +// IsInvalidCookie returns true if error means that magic cookie +// value is invalid. +func (e DecodeErr) IsInvalidCookie() bool { + return e.Place == DecodeErrPlace{"message", "cookie"} +} + // IsPlaceParent reports if error place parent is p. func (e DecodeErr) IsPlaceParent(p string) bool { return e.Place.Parent == p @@ -50,3 +56,8 @@ func newDecodeErr(parent, children, message string) *DecodeErr { Message: message, } } + +// TODO(ar): rewrite errors to be more precise. +func newAttrDecodeErr(children, message string) *DecodeErr { + return newDecodeErr("attribute", children, message) +} diff --git a/fuzz.go b/fuzz.go index 5f58581..5d71e2a 100644 --- a/fuzz.go +++ b/fuzz.go @@ -12,12 +12,12 @@ func FuzzMessage(data []byte) int { // fuzzer dont know about cookies binary.BigEndian.PutUint32(data[4:8], magicCookie) // trying to read data as message - if _, err := m.ReadBytes(data); err != nil { + if _, err := m.Write(data); err != nil { return 0 } m.WriteHeader() m2 := New() - if _, err := m2.ReadBytes(m2.Raw); err != nil { + if _, err := m2.Write(m2.Raw); err != nil { panic(err) } if m2.TransactionID != m.TransactionID { diff --git a/message.go b/message.go new file mode 100644 index 0000000..ae40995 --- /dev/null +++ b/message.go @@ -0,0 +1,491 @@ +// Package stun implements Session Traversal Utilities for NAT (STUN) RFC 5389. +// +// Definitions +// +// STUN Agent: A STUN agent is an entity that implements the STUN +// protocol. The entity can be either a STUN client or a STUN +// server. +// +// STUN Client: A STUN client is an entity that sends STUN requests and +// receives STUN responses. A STUN client can also send indications. +// In this specification, the terms STUN client and client are +// synonymous. +// +// STUN Server: A STUN server is an entity that receives STUN requests +// and sends STUN responses. A STUN server can also send +// indications. In this specification, the terms STUN server and +// server are synonymous. +// +// Transport Address: The combination of an IP address and Port number +// (such as a UDP or TCP Port number). +package stun + +import ( + "crypto/rand" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "strconv" +) + +const ( + // magicCookie is fixed value that aids in distinguishing STUN packets + // from packets of other protocols when STUN is multiplexed with those + // other protocols on the same Port. + // + // The magic cookie field MUST contain the fixed value 0x2112A442 in + // network byte order. + // + // Defined in "STUN Message Structure", section 6. + magicCookie = 0x2112A442 + attributeHeaderSize = 4 + messageHeaderSize = 20 + transactionIDSize = 12 // 96 bit +) + +// NewTransactionID returns new random transaction ID using crypto/rand +// as source. +func NewTransactionID() (b [transactionIDSize]byte) { + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return b +} + +// IsMessage returns true if b looks like STUN message. +// Useful for multiplexing. IsMessage does not guarantee +// that decoding will be successful. +func IsMessage(b []byte) bool { + return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie +} + +// New returns *Message with pre-allocated Raw. +func New() *Message { + const defaultRawCapacity = 120 + return &Message{ + Raw: make([]byte, messageHeaderSize, defaultRawCapacity), + } +} + +// Message represents a single STUN packet. It uses aggressive internal +// buffering to enable zero-allocation encoding and decoding, +// so there are some usage constraints: +// +// * Message and its fields is valid only until AcquireMessage call. +type Message struct { + Type MessageType + Length uint32 // len(Raw) not including header + TransactionID [transactionIDSize]byte + Attributes Attributes + Raw []byte +} + +// NewTransactionID sets m.TransactionID to random value from crypto/rand +// and returns error if any. +func (m *Message) NewTransactionID() error { + _, err := rand.Read(m.TransactionID[:]) + return err +} + +func (m Message) String() string { + return fmt.Sprintf("%s l=%d attrs=%d id=%s", + m.Type, + m.Length, + len(m.Attributes), + base64.StdEncoding.EncodeToString(m.TransactionID[:]), + ) +} + +// Reset resets Message, attributes and underlying buffer length. +func (m *Message) Reset() { + m.Raw = m.Raw[:0] + m.Length = 0 + m.Attributes = m.Attributes[:0] +} + +// grow ensures that internal buffer will fit v more bytes and +// increases it capacity if necessary. +func (m *Message) grow(v int) { + // Not performing any optimizations here + // (e.g. preallocate len(buf) * 2 to reduce allocations) + // because they are already done by []byte implementation. + n := len(m.Raw) + v + for cap(m.Raw) < n { + m.Raw = append(m.Raw, 0) + } + m.Raw = m.Raw[:n] +} + +// Add appends new attribute to message. Not goroutine-safe. +// +// Value of attribute is copied to internal buffer so +// it is safe to reuse v. +func (m *Message) Add(t AttrType, v []byte) { + // Allocating buffer for TLV (type-length-value). + // T = t, L = len(v), V = v. + // m.Raw will look like: + // [0:20] <- message header + // [20:20+m.Length] <- existing message attributes + // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV + // [first:last] <- same as previous + // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer + // T L V + allocSize := attributeHeaderSize + len(v) // len(TLV) = len(TL) + len(V) + first := messageHeaderSize + int(m.Length) // first byte number + last := first + allocSize // last byte number + m.grow(last) // growing cap(Raw) to fit TLV + m.Raw = m.Raw[:last] // now len(Raw) = last + m.Length += uint32(allocSize) // rendering length change + + // Sub-slicing internal buffer to simplify encoding. + buf := m.Raw[first:last] // slice for TLV + value := buf[attributeHeaderSize:] // slice for V + attr := RawAttribute{ + Type: t, // T + Length: uint16(len(v)), // L + Value: value, // V + } + + // Encoding attribute TLV to allocated buffer. + bin.PutUint16(buf[0:2], attr.Type.Value()) // T + bin.PutUint16(buf[2:4], attr.Length) // L + copy(value, v) // V + + // Checking that attribute value needs padding. + if attr.Length%padding != 0 { + // Performing padding. + bytesToAdd := nearestPaddedValueLength(len(v)) - len(v) + last += bytesToAdd + m.grow(last) + // setting all padding bytes to zero + // to prevent data leak from previous + // data in next bytesToAdd bytes + buf = m.Raw[last-bytesToAdd : last] + for i := range buf { + buf[i] = 0 + } + m.Raw = m.Raw[:last] // increasing buffer length + m.Length += uint32(bytesToAdd) // rendering length change + } + m.Attributes = append(m.Attributes, attr) +} + +// Equal returns true if Message b equals to m. +// Ignores m.Raw. +func (m *Message) Equal(b *Message) bool { + if m.Type != b.Type { + return false + } + if m.TransactionID != b.TransactionID { + return false + } + if m.Length != b.Length { + return false + } + for _, a := range m.Attributes { + aB, ok := b.Attributes.Get(a.Type) + if !ok { + return false + } + if !aB.Equal(a) { + return false + } + } + return true +} + +// WriteLength writes m.Length to m.Raw. Call is valid only if len(m.Raw) >= 4. +func (m *Message) WriteLength() { + _ = m.Raw[4] // early bounds check to guarantee safety of writes below + bin.PutUint16(m.Raw[2:4], uint16(m.Length)) +} + +// WriteHeader writes header to underlying buffer. Not goroutine-safe. +func (m *Message) WriteHeader() { + if len(m.Raw) < messageHeaderSize { + // Making WriteHeader call valid even when m.Raw + // is nil or len(m.Raw) is less than needed for header. + m.grow(messageHeaderSize) + } + _ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below + + bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type + bin.PutUint16(m.Raw[2:4], uint16(len(m.Raw)-messageHeaderSize)) // size of payload + bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie + copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID +} + +// WriteAttributes encodes all m.Attributes to m. +func (m *Message) WriteAttributes() { + for _, a := range m.Attributes { + m.Add(a.Type, a.Value) + } +} + +// Encode resets m.Raw and calls WriteHeader and WriteAttributes. +func (m *Message) Encode() { + m.Raw = m.Raw[:0] + m.WriteHeader() + m.WriteAttributes() +} + +// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning +// call result. +func (m *Message) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(m.Raw) + return int64(n), err +} + +// Append appends m.Raw to v. Useful to call after encoding message. +func (m *Message) Append(v []byte) []byte { + return append(v, m.Raw...) +} + +// ReadFrom implements ReaderFrom. Reads message from r into m.Raw, +// Decodes it and return error if any. If m.Raw is too small, will return +// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr. +// +// Can return *DecodeErr while decoding too. +func (m *Message) ReadFrom(r io.Reader) (int64, error) { + tBuf := m.Raw[:cap(m.Raw)] + var ( + n int + err error + ) + if n, err = r.Read(tBuf); err != nil { + return int64(n), err + } + m.Raw = tBuf[:n] + return int64(n), m.Decode() +} + +const ( + // ErrUnexpectedHeaderEOF means that there were not enough bytes in + // m.Raw to read header. + ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header" +) + +// Decode decodes m.Raw into m. +func (m *Message) Decode() error { + // decoding message header + buf := m.Raw + if len(buf) < messageHeaderSize { + return ErrUnexpectedHeaderEOF + } + var ( + t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes + size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes + cookie = binary.BigEndian.Uint32(buf[4:8]) + fullSize = messageHeaderSize + size + ) + if cookie != magicCookie { + msg := fmt.Sprintf( + "%x is invalid magic cookie (should be %x)", + cookie, magicCookie, + ) + return newDecodeErr("message", "cookie", msg) + } + if len(buf) < fullSize { + msg := fmt.Sprintf( + "buffer length %d is less than %d (expected message size)", + len(buf), fullSize, + ) + return newAttrDecodeErr("message", msg) + } + // saving header data + m.Type.ReadValue(t) + m.Length = uint32(size) + copy(m.TransactionID[:], buf[8:messageHeaderSize]) + + var ( + offset = 0 + b = buf[messageHeaderSize:fullSize] + ) + for offset < size { + // checking that we have enough bytes to read header + if len(b) < attributeHeaderSize { + msg := fmt.Sprintf( + "buffer length %d is less than %d (expected header size)", + len(b), attributeHeaderSize, + ) + return newAttrDecodeErr("header", msg) + } + var ( + a = RawAttribute{ + Type: AttrType(bin.Uint16(b[0:2])), // first 2 bytes + Length: bin.Uint16(b[2:4]), // second 2 bytes + } + aL = int(a.Length) // attribute length + aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding) + ) + b = b[attributeHeaderSize:] // slicing again to simplify value read + offset += attributeHeaderSize + if len(b) < aBuffL { // checking size + msg := fmt.Sprintf( + "buffer length %d is less than %d (expected value size)", + len(b), aBuffL, + ) + return newAttrDecodeErr("value", msg) + } + a.Value = b[:aL] + offset += aBuffL + b = b[aBuffL:] + + m.Attributes = append(m.Attributes, a) + } + return nil +} + +// Write decodes message and return error if any. +// +// Any error is unrecoverable, but message could be partially decoded. +func (m *Message) Write(tBuf []byte) (int, error) { + m.Raw = append(m.Raw[:0], tBuf...) + return len(tBuf), m.Decode() +} + +// MaxPacketSize is maximum size of UDP packet that is processable in +// this package for STUN message. +const MaxPacketSize = 2048 + +// MessageClass is 8-bit representation of 2-bit class of STUN Message Class. +type MessageClass byte + +// Possible values for message class in STUN Message Type. +const ( + ClassRequest MessageClass = 0x00 // 0b00 + ClassIndication MessageClass = 0x01 // 0b01 + ClassSuccessResponse MessageClass = 0x02 // 0b10 + ClassErrorResponse MessageClass = 0x03 // 0b11 +) + +func (c MessageClass) String() string { + switch c { + case ClassRequest: + return "request" + case ClassIndication: + return "indication" + case ClassSuccessResponse: + return "success response" + case ClassErrorResponse: + return "error response" + default: + panic("unknown message class") + } +} + +// Method is uint16 representation of 12-bit STUN method. +type Method uint16 + +// Possible methods for STUN Message. +const ( + MethodBinding Method = 0x001 + MethodAllocate Method = 0x003 + MethodRefresh Method = 0x004 + MethodSend Method = 0x006 + MethodData Method = 0x007 + MethodCreatePermission Method = 0x008 + MethodChannelBind Method = 0x009 +) + +func (m Method) String() string { + switch m { + case MethodBinding: + return "binding" + case MethodAllocate: + return "allocate" + case MethodRefresh: + return "refresh" + case MethodSend: + return "send" + case MethodData: + return "data" + case MethodCreatePermission: + return "create permission" + case MethodChannelBind: + return "channel bind" + default: + return fmt.Sprintf("0x%s", strconv.FormatUint(uint64(m), 16)) + } +} + +// MessageType is STUN Message Type Field. +type MessageType struct { + Class MessageClass + Method Method +} + +const ( + methodABits = 0xf // 0b0000000000001111 + methodBBits = 0x70 // 0b0000000001110000 + methodDBits = 0xf80 // 0b0000111110000000 + + methodBShift = 1 + methodDShift = 2 + + firstBit = 0x1 + secondBit = 0x2 + + c0Bit = firstBit + c1Bit = secondBit + + classC0Shift = 4 + classC1Shift = 7 +) + +// Value returns bit representation of messageType. +func (t MessageType) Value() uint16 { + // 0 1 + // 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + // |M |M |M|M|M|C|M|M|M|C|M|M|M|M| + // |11|10|9|8|7|1|6|5|4|0|3|2|1|0| + // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + // Figure 3: Format of STUN Message Type Field + + // Warning: Abandon all hope ye who enter here. + // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11). + m := uint16(t.Method) + a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits) + b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A) + d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B) + + // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). + m = a + (b << methodBShift) + (d << methodDShift) + + // C0 is zero bit of C, C1 is fist bit. + // C0 = C * 0b01, C1 = (C * 0b10) >> 1 + // Ct = C0 << 4 + C1 << 8. + // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7" + // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions + // (see figure 3). + c := uint16(t.Class) + c0 := (c & c0Bit) << classC0Shift + c1 := (c & c1Bit) << classC1Shift + class := c0 + c1 + + return m + class +} + +// ReadValue decodes uint16 into MessageType. +func (t *MessageType) ReadValue(v uint16) { + // Decoding class. + // We are taking first bit from v >> 4 and second from v >> 7. + c0 := (v >> classC0Shift) & c0Bit + c1 := (v >> classC1Shift) & c1Bit + class := c0 + c1 + t.Class = MessageClass(class) + + // Decoding method. + a := v & methodABits // A(M0-M3) + b := (v >> methodBShift) & methodBBits // B(M4-M6) + d := (v >> methodDShift) & methodDBits // D(M7-M11) + m := a + b + d + t.Method = Method(m) +} + +func (t MessageType) String() string { + return fmt.Sprintf("%s %s", t.Method, t.Class) +} diff --git a/padding.go b/padding.go new file mode 100644 index 0000000..5b51cd0 --- /dev/null +++ b/padding.go @@ -0,0 +1,13 @@ +package stun + +const ( + padding = 4 +) + +func nearestPaddedValueLength(l int) int { + n := padding * (l / padding) + if n < l { + n += padding + } + return n +} diff --git a/stun-bench/main.go b/stun-bench/main.go index 0447ab2..36e9c85 100644 --- a/stun-bench/main.go +++ b/stun-bench/main.go @@ -6,7 +6,7 @@ import ( "log" "net" "net/http" - "runtime" + "sync/atomic" "time" "github.com/ernado/stun" @@ -17,13 +17,14 @@ var ( fmt.Sprintf("127.0.0.1:%d", stun.DefaultPort), "addr to attack", ) + readWorkers = flag.Int("read-workers", 1, "concurrent read workers") + writeWorkers = flag.Int("write-workers", 1, "concurrent write workers") count int64 ) func main() { flag.Parse() - runtime.GOMAXPROCS(2) log.SetFlags(log.Lshortfile) go func() { log.Println(http.ListenAndServe("localhost:6060", nil)) @@ -43,30 +44,42 @@ func main() { }, TransactionID: stun.NewTransactionID(), } - m.AddRaw(stun.AttrSoftware, []byte("stun benchmark")) + m.Add(stun.AttrSoftware, []byte("stun benchmark")) m.Encode() - go func() { - mRec := stun.New() - mRec.Raw = make([]byte, 1024) - start := time.Now() - for { - _, err := c.Read(mRec.Raw[:cap(mRec.Raw)]) - if err != nil { - log.Fatalln("read back:", err) + for i := 0; i < *readWorkers; i++ { + go func() { + mRec := stun.New() + mRec.Raw = make([]byte, 1024) + start := time.Now() + for { + _, err := c.Read(mRec.Raw[:cap(mRec.Raw)]) + if err != nil { + log.Fatalln("read back:", err) + } + // mRec.Raw = mRec.Raw[:n] + // if err := mRec.Decode(); err != nil { + // log.Fatalln("Decode:", err) + // } + atomic.AddInt64(&count, 1) + if c := atomic.LoadInt64(&count); c%10000 == 0 { + fmt.Printf("%d\n", c) + elapsed := time.Since(start) + fmt.Println(float64(c)/elapsed.Seconds(), "per second") + } + // mRec.Reset() } - // mRec.Raw = mRec.Raw[:n] - // if err := mRec.Decode(); err != nil { - // log.Fatalln("Decode:", err) - // } - count++ - if count%10000 == 0 { - fmt.Printf("%d\n", count) - elapsed := time.Since(start) - fmt.Println(float64(count)/elapsed.Seconds(), "per second") + }() + } + for i := 1; i < *writeWorkers; i++ { + go func() { + for { + _, err := c.Write(m.Raw) + if err != nil { + log.Fatalln("write:", err) + } } - // mRec.Reset() - } - }() + }() + } for { _, err := c.Write(m.Raw) if err != nil { diff --git a/stun-cli/client_test.go b/stun-cli/client_test.go index 041caa5..4da953c 100644 --- a/stun-cli/client_test.go +++ b/stun-cli/client_test.go @@ -33,7 +33,7 @@ func TestClient_Do(t *testing.T) { m := stun.New() m.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} m.TransactionID = stun.NewTransactionID() - m.Add(stun.NewSoftware("cydev/stun alpha")) + stun.NewSoftware("cydev/stun alpha").AddTo(m) m.WriteHeader() request := Request{ Target: "stun.l.google.com:19302", @@ -43,11 +43,11 @@ func TestClient_Do(t *testing.T) { if r.Message.TransactionID != m.TransactionID { t.Error("transaction id messmatch") } - ip, port, err := r.Message.GetXORMappedAddress() - if err != nil { + addr := new(stun.XORMappedAddress) + if err := addr.GetFrom(m); err != nil { t.Error(err) } - log.Println("got", ip, port) + log.Println("got", addr) return nil }); err != nil { t.Fatal(err) diff --git a/stun-cli/main.go b/stun-cli/main.go index b9bb42a..f83c686 100644 --- a/stun-cli/main.go +++ b/stun-cli/main.go @@ -190,7 +190,7 @@ func discover(c *cli.Context) error { Class: stun.ClassRequest, }, } - m.AddRaw(stun.AttrSoftware, software.Raw) + m.Add(stun.AttrSoftware, software.Raw) m.WriteHeader() request := Request{ @@ -200,15 +200,13 @@ func discover(c *cli.Context) error { return DefaultClient.Do(request, func(r Response) error { var ( - ip net.IP - port int - err error + err error ) - ip, port, err = r.Message.GetXORMappedAddress() - if err != nil { + addr := new(stun.XORMappedAddress) + if err = addr.GetFrom(r.Message); err != nil { return errors.Wrap(err, "failed to get ip") } - fmt.Println(ip, port) + fmt.Println(addr) return nil }) } diff --git a/stun-decode/main.go b/stun-decode/main.go index 03e1b81..92ef474 100644 --- a/stun-decode/main.go +++ b/stun-decode/main.go @@ -23,7 +23,7 @@ func main() { log.Fatalln("Unable to decode bas64 value:", err) } m := stun.New() - if _, err = m.ReadBytes(data); err != nil { + if _, err = m.Write(data); err != nil { log.Fatalln("Unable to decode message:", err) } fmt.Println(m) diff --git a/stun.go b/stun.go index e2d69b5..606ef2f 100644 --- a/stun.go +++ b/stun.go @@ -16,542 +16,14 @@ // indications. In this specification, the terms STUN server and // server are synonymous. // -// Transport Address: The combination of an IP address and port number -// (such as a UDP or TCP port number). +// Transport Address: The combination of an IP address and Port number +// (such as a UDP or TCP Port number). package stun -import ( - "crypto/rand" - "encoding/base64" - "encoding/binary" - "fmt" - "io" - "net" - "strconv" -) +import "encoding/binary" -var ( - // bin is shorthand to binary.BigEndian. - bin = binary.BigEndian -) +// bin is shorthand to binary.BigEndian. +var bin = binary.BigEndian -// DefaultPort is IANA assigned port for "stun" protocol. +// DefaultPort is IANA assigned Port for "stun" protocol. const DefaultPort = 3478 - -const ( - // magicCookie is fixed value that aids in distinguishing STUN packets - // from packets of other protocols when STUN is multiplexed with those - // other protocols on the same port. - // - // The magic cookie field MUST contain the fixed value 0x2112A442 in - // network byte order. - // - // Defined in "STUN Message Structure", section 6. - magicCookie = 0x2112A442 -) - -const transactionIDSize = 12 // 96 bit - -// Message represents a single STUN packet. It uses aggressive internal -// buffering to enable zero-allocation encoding and decoding, -// so there are some usage constraints: -// -// * Message and its fields is valid only until AcquireMessage call. -type Message struct { - Type MessageType - Length uint32 // len(Raw) not including header - TransactionID [transactionIDSize]byte - Attributes Attributes - Raw []byte -} - -// CopyTo copies all m to c. -func (m Message) CopyTo(c *Message) { - c.Type = m.Type - c.Length = m.Length - copy(c.TransactionID[:], m.TransactionID[:]) - buf := m.Raw[:int(m.Length)+messageHeaderSize] - c.Raw = c.Raw[:0] - c.Raw = append(c.Raw, buf...) - buf = c.Raw[messageHeaderSize:] - for _, a := range m.Attributes { - buf = buf[attributeHeaderSize:] - c.Attributes = append(c.Attributes, RawAttribute{ - Length: a.Length, - Type: a.Type, - Value: buf[:int(a.Length)], - }) - buf = buf[int(a.Length):] - } -} - -func (m Message) String() string { - return fmt.Sprintf("%s (l=%d,%d/%d) attr[%d] id[%s]", - m.Type, - m.Length, - len(m.Raw), - cap(m.Raw), - len(m.Attributes), - base64.StdEncoding.EncodeToString(m.TransactionID[:]), - ) -} - -// NewTransactionID returns new random transaction ID using crypto/rand -// as source. -func NewTransactionID() (b [transactionIDSize]byte) { - _, err := rand.Read(b[:]) - if err != nil { - panic(err) - } - return b -} - -// defaults for pool. -const ( - defaultMessageBufferCapacity = 120 -) - -// New returns *Message with allocated Raw. -func New() *Message { - return &Message{ - Raw: make([]byte, messageHeaderSize, defaultMessageBufferCapacity), - } -} - -// Reset resets Message length, attributes and underlying buffer, as well as -// setting readOnly flag to false. -func (m *Message) Reset() { - m.Raw = m.Raw[:0] - m.Length = 0 - m.Attributes = m.Attributes[:0] -} - -// grow ensures that internal buffer will fit v more bytes and -// increases it capacity if necessary. -func (m *Message) grow(v int) { - // Not performing any optimizations here - // (e.g. preallocate len(buf) * 2 to reduce allocations) - // because they are already done by []byte implementation. - n := len(m.Raw) + v - for cap(m.Raw) < n { - m.Raw = append(m.Raw, 0) - } - m.Raw = m.Raw[:n] -} - -// Add adds AttrEncoder to message, calling Encode method. -func (m *Message) Add(a AttrEncoder) error { - var ( - err error - t AttrType - ) - initial := len(m.Raw) - for i := 0; i < attributeHeaderSize; i++ { - m.Raw = append(m.Raw, 0) - } - start := len(m.Raw) - t, m.Raw, err = a.Encode(m.Raw, m) - if err != nil { - m.Raw = m.Raw[:initial] - return err - } - m.AddRaw(t, m.Raw[start:]) - return nil -} - -// AddRaw appends new attribute to message. Not goroutine-safe. -// -// Value of attribute is copied to internal buffer so -// it is safe to reuse v. -func (m *Message) AddRaw(t AttrType, v []byte) { - // CPU: suboptimal; - // allocating memory for TLV (type-length-value), where - // type-length is attribute header. - // m.buf.B[0:20] is reserved by header - // internal buffer will look like: - // [0:20] <- message header - // [20:20+m.Length] <- added message attributes - // [20+m.Length:20+m.Length+len(v) + 4] <- allocated slice for new TLV - // [first:last] <- same as previous - // [0 1|2 3|4 4 + len(v)] - // T L V - allocSize := attributeHeaderSize + len(v) // len(TLV) - first := messageHeaderSize + int(m.Length) // first byte number - last := first + allocSize // last byte number - m.grow(last) // growing cap(b) to fit TLV - m.Raw = m.Raw[:last] // now len(b) = last - m.Length += uint32(allocSize) // rendering length change - // subslicing internal buffer to simplify encoding - buf := m.Raw[first:last] // slice for TLV - value := buf[attributeHeaderSize:] // slice for value - attr := RawAttribute{ - Type: t, - Value: value, - Length: uint16(len(v)), - } - // encoding attribute TLV to internal buffer - bin.PutUint16(buf[0:2], attr.Type.Value()) // type - bin.PutUint16(buf[2:4], attr.Length) // length - copy(value, v) // value - if attr.Length%padding != 0 { - // performing padding - bytesToAdd := nearestLength(len(v)) - len(v) - last += bytesToAdd - m.grow(last) - // setting all padding bytes to zero - // to prevent data leak from previous - // data in next bytesToAdd bytes - buf = m.Raw[last-bytesToAdd : last] - for i := range buf { - buf[i] = 0 - } - m.Raw = m.Raw[:last] // increasing buffer length - m.Length += uint32(bytesToAdd) // rendering length change - } - m.Attributes = append(m.Attributes, attr) -} - -const ( - padding = 4 -) - -func nearestLength(l int) int { - n := padding * (l / padding) - if n < l { - n += padding - } - return n -} - -// Equal returns true if Message b equals to m. -func (m *Message) Equal(b *Message) bool { - if m.Type != b.Type { - return false - } - if m.TransactionID != b.TransactionID { - return false - } - if m.Length != b.Length { - return false - } - for _, a := range m.Attributes { - aB, ok := b.Attributes.Get(a.Type) - if !ok { - return false - } - if !aB.Equal(a) { - return false - } - } - return true -} - -// WriteLength writes m.Length to m.Raw. -func (m *Message) WriteLength() { bin.PutUint16(m.Raw[2:4], uint16(m.Length)) } - -// WriteHeader writes header to underlying buffer. Not goroutine-safe. -func (m *Message) WriteHeader() { - // encoding header - if len(m.Raw) < messageHeaderSize { - m.grow(messageHeaderSize) - } - bin.PutUint16(m.Raw[0:2], m.Type.Value()) - bin.PutUint32(m.Raw[4:8], magicCookie) - copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) - // attributes are already encoded - // writing length as size, in bytes, not including the 20-byte STUN header. - bin.PutUint16(m.Raw[2:4], uint16(len(m.Raw)-messageHeaderSize)) -} - -// WriteAttributes encodes all m.Attributes to m. -func (m *Message) WriteAttributes() { - for _, a := range m.Attributes { - m.AddRaw(a.Type, a.Value) - } -} - -// Encode writes m into Raw. -func (m *Message) Encode() { - m.Raw = m.Raw[:0] - m.WriteHeader() - m.WriteAttributes() -} - -// WriteTo implements WriterTo. -func (m *Message) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write(m.Raw) - return int64(n), err -} - -// WriteToConn writes a packet with message to addr, using c. -// Deprecated; non-idiomatic. -func (m *Message) WriteToConn(c net.PacketConn, addr net.Addr) (n int, err error) { - return c.WriteTo(m.Raw, addr) -} - -// ReadFrom implements ReaderFrom. Reads message from r into m.Raw, -// Decodes it and return error if any. If m.Raw is too small, will return -// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr. -// -// Can return *DecodeErr while decoding. -func (m *Message) ReadFrom(r io.Reader) (int64, error) { - tBuf := m.Raw[:cap(m.Raw)] - var ( - n int - err error - ) - if n, err = r.Read(tBuf); err != nil { - return int64(n), err - } - m.Raw = tBuf[:n] - return int64(n), m.Decode() -} - -func newAttrDecodeErr(children, message string) *DecodeErr { - return newDecodeErr("attribute", children, message) -} - -// IsMessage returns true if b looks like STUN message. -// Useful for multiplexing. IsMessage does not guarantee -// that decoding will be successful. -func IsMessage(b []byte) bool { - // b should be at least messageHeaderSize bytes and - // contain correct magicCookie. - return len(b) >= messageHeaderSize && - binary.BigEndian.Uint32(b[4:8]) == magicCookie -} - -const ( - // ErrUnexpectedHeaderEOF means that there were not enough bytes in - // m.Raw to read header. - ErrUnexpectedHeaderEOF Error = "unexpected EOF: not enough bytes to read header" -) - -// Decode decodes m.Raw into m. -func (m *Message) Decode() error { - // decoding message header - buf := m.Raw - if len(buf) < messageHeaderSize { - return ErrUnexpectedHeaderEOF - } - var ( - t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes - size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes - cookie = binary.BigEndian.Uint32(buf[4:8]) - fullSize = messageHeaderSize + size - ) - if cookie != magicCookie { - msg := fmt.Sprintf( - "%x is invalid magic cookie (should be %x)", - cookie, magicCookie, - ) - return newDecodeErr("message", "cookie", msg) - } - if len(buf) < fullSize { - msg := fmt.Sprintf( - "buffer length %d is less than %d (expected message size)", - len(buf), fullSize, - ) - return newAttrDecodeErr("message", msg) - } - // saving header data - m.Type.ReadValue(t) - m.Length = uint32(size) - copy(m.TransactionID[:], buf[8:messageHeaderSize]) - - var ( - offset = 0 - b = buf[messageHeaderSize:fullSize] - ) - for offset < size { - // checking that we have enough bytes to read header - if len(b) < attributeHeaderSize { - msg := fmt.Sprintf( - "buffer length %d is less than %d (expected header size)", - len(b), attributeHeaderSize, - ) - return newAttrDecodeErr("header", msg) - } - var ( - a = RawAttribute{ - Type: AttrType(bin.Uint16(b[0:2])), // first 2 bytes - Length: bin.Uint16(b[2:4]), // second 2 bytes - } - aL = int(a.Length) // attribute length - aBuffL = nearestLength(aL) // expected buffer length (with padding) - ) - b = b[attributeHeaderSize:] // slicing again to simplify value read - offset += attributeHeaderSize - if len(b) < aBuffL { // checking size - msg := fmt.Sprintf( - "buffer length %d is less than %d (expected value size)", - len(b), aBuffL, - ) - return newAttrDecodeErr("value", msg) - } - a.Value = b[:aL] - offset += aBuffL - b = b[aBuffL:] - - m.Attributes = append(m.Attributes, a) - } - return nil -} - -// ReadBytes decodes message and return error if any. -// -// Any error is unrecoverable, but message could be partially decoded. -// Deprecated: use m.Decode. -func (m *Message) ReadBytes(tBuf []byte) (int, error) { - m.Raw = append(m.Raw[:0], tBuf...) - return len(tBuf), m.Decode() -} - -const ( - attributeHeaderSize = 4 - messageHeaderSize = 20 -) - -// MaxPacketSize is maximum size of UDP packet that is processable in -// this package for STUN message. -const MaxPacketSize = 2048 - -// MessageClass is 8-bit representation of 2-bit class of STUN Message Class. -type MessageClass byte - -// Possible values for message class in STUN Message Type. -const ( - ClassRequest MessageClass = 0x00 // 0b00 - ClassIndication MessageClass = 0x01 // 0b01 - ClassSuccessResponse MessageClass = 0x02 // 0b10 - ClassErrorResponse MessageClass = 0x03 // 0b11 -) - -func (c MessageClass) String() string { - switch c { - case ClassRequest: - return "request" - case ClassIndication: - return "indication" - case ClassSuccessResponse: - return "success response" - case ClassErrorResponse: - return "error response" - default: - panic("unknown message class") - } -} - -// Method is uint16 representation of 12-bit STUN method. -type Method uint16 - -// Possible methods for STUN Message. -const ( - MethodBinding Method = 0x001 - MethodAllocate Method = 0x003 - MethodRefresh Method = 0x004 - MethodSend Method = 0x006 - MethodData Method = 0x007 - MethodCreatePermission Method = 0x008 - MethodChannelBind Method = 0x009 -) - -func (m Method) String() string { - switch m { - case MethodBinding: - return "binding" - case MethodAllocate: - return "allocate" - case MethodRefresh: - return "refresh" - case MethodSend: - return "send" - case MethodData: - return "data" - case MethodCreatePermission: - return "create permission" - case MethodChannelBind: - return "channel bind" - default: - return fmt.Sprintf("0x%s", strconv.FormatUint(uint64(m), 16)) - } -} - -// MessageType is STUN Message Type Field. -type MessageType struct { - Class MessageClass - Method Method -} - -const ( - methodABits = 0xf // 0b0000000000001111 - methodBBits = 0x70 // 0b0000000001110000 - methodDBits = 0xf80 // 0b0000111110000000 - - methodBShift = 1 - methodDShift = 2 - - firstBit = 0x1 - secondBit = 0x2 - - c0Bit = firstBit - c1Bit = secondBit - - classC0Shift = 4 - classC1Shift = 7 -) - -// Value returns bit representation of messageType. -func (t MessageType) Value() uint16 { - // 0 1 - // 2 3 4 5 6 7 8 9 0 1 2 3 4 5 - // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - // |M |M |M|M|M|C|M|M|M|C|M|M|M|M| - // |11|10|9|8|7|1|6|5|4|0|3|2|1|0| - // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - // Figure 3: Format of STUN Message Type Field - - // warning: Abandon all hope ye who enter here. - // splitting M into A(M0-M3), B(M4-M6), D(M7-M11) - m := uint16(t.Method) - a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits) - b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A) - d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B) - - // shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit) - m = a + (b << methodBShift) + (d << methodDShift) - - // C0 is zero bit of C, C1 is fist bit. - // C0 = C * 0b01, C1 = (C * 0b10) >> 1 - // Ct = C0 << 4 + C1 << 8. - // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7" - // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions - // (see figure 3). - c := uint16(t.Class) - c0 := (c & c0Bit) << classC0Shift - c1 := (c & c1Bit) << classC1Shift - class := c0 + c1 - - return m + class -} - -// ReadValue decodes uint16 into MessageType. -func (t *MessageType) ReadValue(v uint16) { - // decoding class - // we are taking first bit from v >> 4 and second from v >> 7. - c0 := (v >> classC0Shift) & c0Bit - c1 := (v >> classC1Shift) & c1Bit - class := c0 + c1 - t.Class = MessageClass(class) - - // decoding method - a := v & methodABits // A(M0-M3) - b := (v >> methodBShift) & methodBBits // B(M4-M6) - d := (v >> methodDShift) & methodDBits // D(M7-M11) - m := a + b + d - t.Method = Method(m) -} - -func (t MessageType) String() string { - return fmt.Sprintf("%s %s", t.Method, t.Class) -} diff --git a/stun_test.go b/stun_test.go index 3ccf3ee..a0766ed 100644 --- a/stun_test.go +++ b/stun_test.go @@ -9,14 +9,13 @@ import ( "hash/crc64" "io" "io/ioutil" + "net" "os" "path/filepath" "strconv" "strings" "testing" - "net" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -33,24 +32,11 @@ func (m *Message) reader() *bytes.Reader { return bytes.NewReader(m.Raw) } -func TestMessageCopy(t *testing.T) { - m := New() - m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} - m.TransactionID = NewTransactionID() - m.AddRaw(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) - m.WriteHeader() - mCopy := New() - m.CopyTo(mCopy) - if !mCopy.Equal(m) { - t.Error(mCopy, "!=", m) - } -} - func TestMessageBuffer(t *testing.T) { m := New() m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} m.TransactionID = NewTransactionID() - m.AddRaw(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) + m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) m.WriteHeader() mDecoded := New() if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil { @@ -69,7 +55,7 @@ func BenchmarkMessage_Write(b *testing.B) { transactionID := NewTransactionID() m := New() for i := 0; i < b.N; i++ { - m.AddRaw(AttrErrorCode, attributeValue) + m.Add(AttrErrorCode, attributeValue) m.TransactionID = transactionID m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} m.WriteHeader() @@ -133,6 +119,25 @@ func TestMessageType_ReadWriteValue(t *testing.T) { } } +func TestMessage_WriteTo(t *testing.T) { + m := New() + m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} + m.TransactionID = NewTransactionID() + m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) + m.WriteHeader() + buf := new(bytes.Buffer) + if _, err := m.WriteTo(buf); err != nil { + t.Fatal(err) + } + mDecoded := New() + if _, err := mDecoded.ReadFrom(buf); err != nil { + t.Error(err) + } + if !mDecoded.Equal(m) { + t.Error(mDecoded, "!", m) + } +} + func TestMessage_Cookie(t *testing.T) { buf := make([]byte, 20) mDecoded := New() @@ -156,11 +161,11 @@ func TestMessage_BadLength(t *testing.T) { Length: 4, TransactionID: [transactionIDSize]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, } - m.AddRaw(0x1, []byte{1, 2}) + m.Add(0x1, []byte{1, 2}) m.WriteHeader() m.Raw[20+3] = 10 // set attr length = 10 mDecoded := New() - if _, err := mDecoded.ReadBytes(m.Raw); err == nil { + if _, err := mDecoded.Write(m.Raw); err == nil { t.Error("should error") } } @@ -295,7 +300,7 @@ func BenchmarkMessage_ReadBytes(b *testing.B) { b.SetBytes(int64(len(m.Raw))) mRec := New() for i := 0; i < b.N; i++ { - if _, err := mRec.ReadBytes(m.Raw); err != nil { + if _, err := mRec.Write(m.Raw); err != nil { b.Fatal(err) } mRec.Reset() @@ -439,9 +444,7 @@ func TestMessage_String(t *testing.T) { func TestIsMessage(t *testing.T) { m := New() - m.Add(&Software{ - Raw: []byte("test"), - }) + NewSoftware("software").AddTo(m) m.WriteHeader() var tt = [...]struct { @@ -468,7 +471,7 @@ func BenchmarkIsMessage(b *testing.B) { m := New() m.Type = MessageType{Method: MethodBinding, Class: ClassRequest} m.TransactionID = NewTransactionID() - m.AddSoftware("cydev/stun test") + NewSoftware("cydev/stun test").AddTo(m) m.WriteHeader() b.SetBytes(int64(messageHeaderSize)) @@ -502,13 +505,13 @@ func loadData(tb testing.TB, name string) []byte { func TestExampleChrome(t *testing.T) { buf := loadData(t, "ex1_chrome.stun") m := New() - _, err := m.ReadBytes(buf) + _, err := m.Write(buf) if err != nil { t.Errorf("Failed to parse ex1_chrome: %s", err) } } -func TestNearestLen(t *testing.T) { +func TestPadding(t *testing.T) { tt := []struct { in, out int }{ @@ -525,7 +528,7 @@ func TestNearestLen(t *testing.T) { {40, 40}, // 10 } for i, c := range tt { - if got := nearestLength(c.in); got != c.out { + if got := nearestPaddedValueLength(c.in); got != c.out { t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)", i, c.in, got, c.out, ) @@ -562,7 +565,7 @@ func TestMessageFromBrowsers(t *testing.T) { if b != crc64.Checksum(data, crcTable) { t.Error("crc64 check failed for ", line[1]) } - if _, err = m.ReadBytes(data); err != nil { + if _, err = m.Write(data); err != nil { t.Error("failed to decode ", line[1], " as message: ", err) } m.Reset() @@ -583,20 +586,29 @@ func BenchmarkNewTransactionID(b *testing.B) { } } +func BenchmarkMessage_NewTransactionID(b *testing.B) { + b.ReportAllocs() + m := new(Message) + for i := 0; i < b.N; i++ { + if err := m.NewTransactionID(); err != nil { + b.Fatal(err) + } + } +} + func BenchmarkMessageFull(b *testing.B) { b.ReportAllocs() m := new(Message) s := NewSoftware("software") addr := &XORMappedAddress{ - ip: net.IPv4(213, 1, 223, 5), + IP: net.IPv4(213, 1, 223, 5), } - fingerprint := new(Fingerprint) for i := 0; i < b.N; i++ { - m.Add(addr) - m.Add(s) + addAttr(b, m, addr) + addAttr(b, m, s) m.WriteAttributes() m.WriteHeader() - fingerprint.AddTo(m) + Fingerprint.AddTo(m) m.WriteHeader() m.Reset() } @@ -607,19 +619,26 @@ func BenchmarkMessageFullHardcore(b *testing.B) { m := new(Message) s := NewSoftware("software") addr := &XORMappedAddress{ - ip: net.IPv4(213, 1, 223, 5), + IP: net.IPv4(213, 1, 223, 5), } - t, v, _ := addr.Encode(nil, m) for i := 0; i < b.N; i++ { - m.AddRaw(AttrSoftware, s.Raw) - m.AddRaw(t, v) - m.WriteAttributes() + if err := addr.AddTo(m); err != nil { + b.Fatal(err) + } + if err := s.AddTo(m); err != nil { + b.Fatal(err) + } + m.WriteHeader() m.Reset() } } -func addAttr(t testing.TB, m *Message, a AttrEncoder) { - if err := m.Add(a); err != nil { +type attributeEncoder interface { + AddTo(m *Message) error +} + +func addAttr(t testing.TB, m *Message, a attributeEncoder) { + if err := a.AddTo(m); err != nil { t.Error(err) } } @@ -629,14 +648,13 @@ func BenchmarkFingerprint_AddTo(b *testing.B) { m := new(Message) s := NewSoftware("software") addr := &XORMappedAddress{ - ip: net.IPv4(213, 1, 223, 5), + IP: net.IPv4(213, 1, 223, 5), } addAttr(b, m, addr) addAttr(b, m, s) - fingerprint := new(Fingerprint) b.SetBytes(int64(len(m.Raw))) for i := 0; i < b.N; i++ { - fingerprint.AddTo(m) + Fingerprint.AddTo(m) m.WriteLength() m.Length -= attributeHeaderSize + fingerprintSize m.Raw = m.Raw[:m.Length+messageHeaderSize] @@ -648,12 +666,10 @@ func TestFingerprint_Check(t *testing.T) { m := new(Message) addAttr(t, m, NewSoftware("software")) m.WriteHeader() - fInit := new(Fingerprint) - fInit.AddTo(m) + Fingerprint.AddTo(m) m.WriteHeader() - f := new(Fingerprint) - if err := f.Check(m); err != nil { - t.Errorf("decoded: %d, encoded: %d, err: %s", f.Value, fInit.Value, err) + if err := Fingerprint.Check(m); err != nil { + t.Error(err) } } @@ -662,17 +678,16 @@ func BenchmarkFingerprint_Check(b *testing.B) { m := new(Message) s := NewSoftware("software") addr := &XORMappedAddress{ - ip: net.IPv4(213, 1, 223, 5), + IP: net.IPv4(213, 1, 223, 5), } addAttr(b, m, addr) addAttr(b, m, s) - fingerprint := new(Fingerprint) m.WriteHeader() - fingerprint.AddTo(m) + Fingerprint.AddTo(m) m.WriteHeader() b.SetBytes(int64(len(m.Raw))) for i := 0; i < b.N; i++ { - if err := fingerprint.Check(m); err != nil { + if err := Fingerprint.Check(m); err != nil { b.Fatal(err) } } diff --git a/stund/main.go b/stund/main.go index 9c6244e..643940d 100644 --- a/stund/main.go +++ b/stund/main.go @@ -5,9 +5,11 @@ import ( "fmt" "log" "net" - "runtime" + "net/http" "strings" + _ "net/http/pprof" + "github.com/ernado/stun" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -16,6 +18,8 @@ import ( var ( network = flag.String("net", "udp", "network to listen") address = flag.String("addr", "0.0.0.0:3478", "address to listen") + workers = flag.Int("workers", 1, "workers to start") + profile = flag.Bool("profile", false, "profile") ) // Server is RFC 5389 basic server implementation. @@ -80,7 +84,7 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error { if !stun.IsMessage(b) { return errNotSTUNMessage } - if _, err := req.ReadBytes(b); err != nil { + if _, err := req.Write(b); err != nil { return errors.Wrap(err, "failed to read message") } res.TransactionID = req.TransactionID @@ -99,8 +103,12 @@ func basicProcess(addr net.Addr, b []byte, req, res *stun.Message) error { default: panic(fmt.Sprintf("unknown addr: %v", addr)) } - res.AddXORMappedAddress(ip, port) - res.AddRaw(stun.AttrSoftware, software.Raw) + xma := &stun.XORMappedAddress{ + IP: ip, + Port: port, + } + xma.AddTo(res) + software.AddTo(res) res.WriteHeader() return nil } @@ -116,8 +124,8 @@ func (s *Server) serveConn(c net.PacketConn, res, req *stun.Message) error { return nil } // s.logger().Printf("read %d bytes from %s", n, addr) - if _, err = req.ReadBytes(buf[:n]); err != nil { - s.logger().Printf("ReadBytes: %v", err) + if _, err = req.Write(buf[:n]); err != nil { + s.logger().Printf("Write: %v", err) return err } if err = basicProcess(addr, buf[:n], req, res); err != nil { @@ -156,8 +164,10 @@ func ListenUDPAndServe(serverNet, laddr string) error { if err != nil { return err } - s := &Server{} - return s.Serve(c) + for i := 1; i < *workers; i++ { + go new(Server).Serve(c) + } + return new(Server).Serve(c) } func normalize(address string) string { @@ -172,7 +182,11 @@ func normalize(address string) string { func main() { flag.Parse() - runtime.GOMAXPROCS(1) + if *profile { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + } switch *network { case "udp": normalized := normalize(*address) diff --git a/stund/server_test.go b/stund/server_test.go index 026c470..9708229 100644 --- a/stund/server_test.go +++ b/stund/server_test.go @@ -17,7 +17,7 @@ func BenchmarkBasicProcess(b *testing.B) { b.Fatal(err) } m.TransactionID = stun.NewTransactionID() - m.Add(stun.NewSoftware("some software")) + stun.NewSoftware("some software").AddTo(m) m.WriteHeader() b.SetBytes(int64(len(m.Raw))) for i := 0; i < b.N; i++ { diff --git a/xor.go b/xor.go index 5495031..34365eb 100644 --- a/xor.go +++ b/xor.go @@ -34,7 +34,7 @@ func fastXORBytes(dst, a, b []byte) int { } } - for i := (n - n%wordSize); i < n; i++ { + for i := n - n%wordSize; i < n; i++ { dst[i] = a[i] ^ b[i] }