a:addr: implemented with ALTERNATE-SERVER

This commit is contained in:
Aleksandr Razumov
2017-02-12 23:25:19 +03:00
parent f4af1f1a53
commit 4cb0dd231a
3 changed files with 201 additions and 4 deletions

105
addr.go Normal file
View File

@@ -0,0 +1,105 @@
package stun
import (
"fmt"
"net"
"strconv"
)
// MappedAddress represents MAPPED-ADDRESS attribute.
//
// This attribute is used only by servers for achieving backwards
// compatibility with RFC 3489 clients.
// https://tools.ietf.org/html/rfc5389#section-15.4
type MappedAddress struct {
IP net.IP
Port int
}
// AlternateServer represents ALTERNATE-SERVER attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.4
type AlternateServer struct {
IP net.IP
Port int
}
// AddTo adds ALTERNATE-SERVER attribute to message.
func (s *AlternateServer) AddTo(m *Message) error {
a := (*MappedAddress)(s)
return a.addAs(m, AttrAlternateServer)
}
// GetFrom decodes ALTERNATE-SERVER from message.
func (s *AlternateServer) GetFrom(m *Message) error {
a := (*MappedAddress)(s)
return a.getAs(m, AttrAlternateServer)
}
func (a MappedAddress) String() string {
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
}
func (a *MappedAddress) getAs(m *Message, t AttrType) error {
v, err := m.Get(t)
if err != nil {
return err
}
family := bin.Uint16(v[0:2])
if family != familyIPv6 && family != familyIPv4 {
return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family),
)
}
ipLen := net.IPv4len
if family == familyIPv6 {
ipLen = net.IPv6len
}
// Ensuring len(a.IP) == ipLen and reusing a.IP.
if len(a.IP) < ipLen {
a.IP = a.IP[:cap(a.IP)]
for len(a.IP) < ipLen {
a.IP = append(a.IP, 0)
}
}
a.IP = a.IP[:ipLen]
for i := range a.IP {
a.IP[i] = 0
}
a.Port = int(bin.Uint16(v[2:4]))
copy(a.IP, v[4:])
return nil
}
func (a *MappedAddress) addAs(m *Message, t AttrType) error {
var (
family = familyIPv4
ip = a.IP
)
if len(a.IP) == net.IPv6len {
if isIPv4(ip) {
ip = ip[12:16] // like in ip.To4()
} else {
family = familyIPv6
}
} else if len(ip) != net.IPv4len {
return ErrBadIPLength
}
value := make([]byte, 128)
value[0] = 0 // first 8 bits are zeroes
bin.PutUint16(value[0:2], family)
bin.PutUint16(value[2:4], uint16(a.Port))
copy(value[4:], ip)
m.Add(t, value[:4+len(ip)])
return nil
}
// AddTo adds MAPPED-ADDRESS to message.
func (a *MappedAddress) AddTo(m *Message) error {
return a.addAs(m, AttrMappedAddress)
}
// GetFrom decodes MAPPED-ADDRESS from message.
func (a *MappedAddress) GetFrom(m *Message) error {
return a.getAs(m, AttrMappedAddress)
}

92
addr_test.go Normal file
View File

@@ -0,0 +1,92 @@
package stun
import (
"net"
"testing"
)
func TestMappedAddress(t *testing.T) {
m := new(Message)
addr := &MappedAddress{
IP: net.ParseIP("122.12.34.5"),
Port: 5412,
}
t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil {
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) {
got := new(MappedAddress)
if err := got.GetFrom(m); err != nil {
t.Error(err)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
t.Run("Not found", func(t *testing.T) {
message := new(Message)
if err := got.GetFrom(message); err != ErrAttributeNotFound {
t.Error("should be not found: ", err)
}
})
})
})
}
func TestAlternateServer(t *testing.T) {
m := new(Message)
addr := &AlternateServer{
IP: net.ParseIP("122.12.34.5"),
Port: 5412,
}
t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil {
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) {
got := new(AlternateServer)
if err := got.GetFrom(m); err != nil {
t.Error(err)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
t.Run("Not found", func(t *testing.T) {
message := new(Message)
if err := got.GetFrom(message); err != ErrAttributeNotFound {
t.Error("should be not found: ", err)
}
})
})
})
}
func BenchmarkMappedAddress_AddTo(b *testing.B) {
m := new(Message)
b.ReportAllocs()
addr := &MappedAddress{
IP: net.ParseIP("122.12.34.5"),
Port: 5412,
}
for i := 0; i < b.N; i++ {
if err := addr.AddTo(m); err != nil {
b.Fatal(err)
}
m.Reset()
}
}
func BenchmarkAlternateServer_AddTo(b *testing.B) {
m := new(Message)
b.ReportAllocs()
addr := &AlternateServer{
IP: net.ParseIP("122.12.34.5"),
Port: 5412,
}
for i := 0; i < b.N; i++ {
if err := addr.AddTo(m); err != nil {
b.Fatal(err)
}
m.Reset()
}
}

View File

@@ -8,8 +8,8 @@ import (
) )
const ( const (
familyIPv4 byte = 0x01 familyIPv4 uint16 = 0x01
familyIPv6 byte = 0x02 familyIPv6 uint16 = 0x02
) )
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute. // XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
@@ -62,7 +62,7 @@ func (a *XORMappedAddress) AddTo(m *Message) error {
xorValue := make([]byte, net.IPv6len) xorValue := make([]byte, net.IPv6len)
copy(xorValue[4:], m.TransactionID[:]) copy(xorValue[4:], m.TransactionID[:])
bin.PutUint32(xorValue[0:4], magicCookie) bin.PutUint32(xorValue[0:4], magicCookie)
bin.PutUint16(value[0:2], uint16(family)) bin.PutUint16(value[0:2], family)
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16))
xorBytes(value[4:4+len(ip)], ip, xorValue) xorBytes(value[4:4+len(ip)], ip, xorValue)
m.Add(AttrXORMappedAddress, value[:4+len(ip)]) m.Add(AttrXORMappedAddress, value[:4+len(ip)])
@@ -95,7 +95,7 @@ func (a *XORMappedAddress) GetFrom(m *Message) error {
if err != nil { if err != nil {
return err return err
} }
family := byte(bin.Uint16(v[0:2])) family := bin.Uint16(v[0:2])
if family != familyIPv6 && family != familyIPv4 { if family != familyIPv6 && family != familyIPv4 {
return newDecodeErr("xor-mapped address", "family", return newDecodeErr("xor-mapped address", "family",
fmt.Sprintf("bad value %d", family), fmt.Sprintf("bad value %d", family),