diff --git a/addr.go b/addr.go new file mode 100644 index 0000000..4fefe58 --- /dev/null +++ b/addr.go @@ -0,0 +1,105 @@ +package stun + +import ( + "fmt" + "net" + "strconv" +) + +// MappedAddress represents MAPPED-ADDRESS attribute. +// +// This attribute is used only by servers for achieving backwards +// compatibility with RFC 3489 clients. +// https://tools.ietf.org/html/rfc5389#section-15.4 +type MappedAddress struct { + IP net.IP + Port int +} + +// AlternateServer represents ALTERNATE-SERVER attribute. +// +// https://tools.ietf.org/html/rfc5389#section-15.4 +type AlternateServer struct { + IP net.IP + Port int +} + +// AddTo adds ALTERNATE-SERVER attribute to message. +func (s *AlternateServer) AddTo(m *Message) error { + a := (*MappedAddress)(s) + return a.addAs(m, AttrAlternateServer) +} + +// GetFrom decodes ALTERNATE-SERVER from message. +func (s *AlternateServer) GetFrom(m *Message) error { + a := (*MappedAddress)(s) + return a.getAs(m, AttrAlternateServer) +} + +func (a MappedAddress) String() string { + return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) +} + +func (a *MappedAddress) getAs(m *Message, t AttrType) error { + v, err := m.Get(t) + if err != nil { + return err + } + family := bin.Uint16(v[0:2]) + if family != familyIPv6 && family != familyIPv4 { + return newDecodeErr("xor-mapped address", "family", + fmt.Sprintf("bad value %d", family), + ) + } + ipLen := net.IPv4len + if family == familyIPv6 { + ipLen = net.IPv6len + } + // Ensuring len(a.IP) == ipLen and reusing a.IP. + if len(a.IP) < ipLen { + a.IP = a.IP[:cap(a.IP)] + for len(a.IP) < ipLen { + a.IP = append(a.IP, 0) + } + } + a.IP = a.IP[:ipLen] + for i := range a.IP { + a.IP[i] = 0 + } + a.Port = int(bin.Uint16(v[2:4])) + copy(a.IP, v[4:]) + return nil +} + +func (a *MappedAddress) addAs(m *Message, t AttrType) error { + var ( + family = familyIPv4 + ip = a.IP + ) + if len(a.IP) == net.IPv6len { + if isIPv4(ip) { + ip = ip[12:16] // like in ip.To4() + } else { + family = familyIPv6 + } + } else if len(ip) != net.IPv4len { + return ErrBadIPLength + } + value := make([]byte, 128) + value[0] = 0 // first 8 bits are zeroes + bin.PutUint16(value[0:2], family) + bin.PutUint16(value[2:4], uint16(a.Port)) + copy(value[4:], ip) + m.Add(t, value[:4+len(ip)]) + return nil +} + +// AddTo adds MAPPED-ADDRESS to message. +func (a *MappedAddress) AddTo(m *Message) error { + return a.addAs(m, AttrMappedAddress) +} + +// GetFrom decodes MAPPED-ADDRESS from message. +func (a *MappedAddress) GetFrom(m *Message) error { + return a.getAs(m, AttrMappedAddress) +} diff --git a/addr_test.go b/addr_test.go new file mode 100644 index 0000000..3ce9be6 --- /dev/null +++ b/addr_test.go @@ -0,0 +1,92 @@ +package stun + +import ( + "net" + "testing" +) + +func TestMappedAddress(t *testing.T) { + m := new(Message) + addr := &MappedAddress{ + IP: net.ParseIP("122.12.34.5"), + Port: 5412, + } + t.Run("AddTo", func(t *testing.T) { + if err := addr.AddTo(m); err != nil { + t.Error(err) + } + t.Run("GetFrom", func(t *testing.T) { + got := new(MappedAddress) + if err := got.GetFrom(m); err != nil { + t.Error(err) + } + if !got.IP.Equal(addr.IP) { + t.Error("got bad IP: ", got.IP) + } + t.Run("Not found", func(t *testing.T) { + message := new(Message) + if err := got.GetFrom(message); err != ErrAttributeNotFound { + t.Error("should be not found: ", err) + } + }) + }) + }) +} + +func TestAlternateServer(t *testing.T) { + m := new(Message) + addr := &AlternateServer{ + IP: net.ParseIP("122.12.34.5"), + Port: 5412, + } + t.Run("AddTo", func(t *testing.T) { + if err := addr.AddTo(m); err != nil { + t.Error(err) + } + t.Run("GetFrom", func(t *testing.T) { + got := new(AlternateServer) + if err := got.GetFrom(m); err != nil { + t.Error(err) + } + if !got.IP.Equal(addr.IP) { + t.Error("got bad IP: ", got.IP) + } + t.Run("Not found", func(t *testing.T) { + message := new(Message) + if err := got.GetFrom(message); err != ErrAttributeNotFound { + t.Error("should be not found: ", err) + } + }) + }) + }) +} + +func BenchmarkMappedAddress_AddTo(b *testing.B) { + m := new(Message) + b.ReportAllocs() + addr := &MappedAddress{ + IP: net.ParseIP("122.12.34.5"), + Port: 5412, + } + for i := 0; i < b.N; i++ { + if err := addr.AddTo(m); err != nil { + b.Fatal(err) + } + m.Reset() + } +} + +func BenchmarkAlternateServer_AddTo(b *testing.B) { + m := new(Message) + b.ReportAllocs() + addr := &AlternateServer{ + IP: net.ParseIP("122.12.34.5"), + Port: 5412, + } + for i := 0; i < b.N; i++ { + if err := addr.AddTo(m); err != nil { + b.Fatal(err) + } + m.Reset() + } +} diff --git a/xoraddr.go b/xoraddr.go index 112685d..eeda31f 100644 --- a/xoraddr.go +++ b/xoraddr.go @@ -8,8 +8,8 @@ import ( ) const ( - familyIPv4 byte = 0x01 - familyIPv6 byte = 0x02 + familyIPv4 uint16 = 0x01 + familyIPv6 uint16 = 0x02 ) // XORMappedAddress implements XOR-MAPPED-ADDRESS attribute. @@ -62,7 +62,7 @@ func (a *XORMappedAddress) AddTo(m *Message) error { xorValue := make([]byte, net.IPv6len) copy(xorValue[4:], m.TransactionID[:]) bin.PutUint32(xorValue[0:4], magicCookie) - bin.PutUint16(value[0:2], uint16(family)) + bin.PutUint16(value[0:2], family) bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) xorBytes(value[4:4+len(ip)], ip, xorValue) m.Add(AttrXORMappedAddress, value[:4+len(ip)]) @@ -95,7 +95,7 @@ func (a *XORMappedAddress) GetFrom(m *Message) error { if err != nil { return err } - family := byte(bin.Uint16(v[0:2])) + family := bin.Uint16(v[0:2]) if family != familyIPv6 && family != familyIPv4 { return newDecodeErr("xor-mapped address", "family", fmt.Sprintf("bad value %d", family),