mirror of
https://github.com/pion/stun.git
synced 2025-09-27 20:22:08 +08:00
a:addr: implemented with ALTERNATE-SERVER
This commit is contained in:
105
addr.go
Normal file
105
addr.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// MappedAddress represents MAPPED-ADDRESS attribute.
|
||||
//
|
||||
// This attribute is used only by servers for achieving backwards
|
||||
// compatibility with RFC 3489 clients.
|
||||
// https://tools.ietf.org/html/rfc5389#section-15.4
|
||||
type MappedAddress struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
}
|
||||
|
||||
// AlternateServer represents ALTERNATE-SERVER attribute.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc5389#section-15.4
|
||||
type AlternateServer struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
}
|
||||
|
||||
// AddTo adds ALTERNATE-SERVER attribute to message.
|
||||
func (s *AlternateServer) AddTo(m *Message) error {
|
||||
a := (*MappedAddress)(s)
|
||||
return a.addAs(m, AttrAlternateServer)
|
||||
}
|
||||
|
||||
// GetFrom decodes ALTERNATE-SERVER from message.
|
||||
func (s *AlternateServer) GetFrom(m *Message) error {
|
||||
a := (*MappedAddress)(s)
|
||||
return a.getAs(m, AttrAlternateServer)
|
||||
}
|
||||
|
||||
func (a MappedAddress) String() string {
|
||||
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
|
||||
}
|
||||
|
||||
func (a *MappedAddress) getAs(m *Message, t AttrType) error {
|
||||
v, err := m.Get(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
family := bin.Uint16(v[0:2])
|
||||
if family != familyIPv6 && family != familyIPv4 {
|
||||
return newDecodeErr("xor-mapped address", "family",
|
||||
fmt.Sprintf("bad value %d", family),
|
||||
)
|
||||
}
|
||||
ipLen := net.IPv4len
|
||||
if family == familyIPv6 {
|
||||
ipLen = net.IPv6len
|
||||
}
|
||||
// Ensuring len(a.IP) == ipLen and reusing a.IP.
|
||||
if len(a.IP) < ipLen {
|
||||
a.IP = a.IP[:cap(a.IP)]
|
||||
for len(a.IP) < ipLen {
|
||||
a.IP = append(a.IP, 0)
|
||||
}
|
||||
}
|
||||
a.IP = a.IP[:ipLen]
|
||||
for i := range a.IP {
|
||||
a.IP[i] = 0
|
||||
}
|
||||
a.Port = int(bin.Uint16(v[2:4]))
|
||||
copy(a.IP, v[4:])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *MappedAddress) addAs(m *Message, t AttrType) error {
|
||||
var (
|
||||
family = familyIPv4
|
||||
ip = a.IP
|
||||
)
|
||||
if len(a.IP) == net.IPv6len {
|
||||
if isIPv4(ip) {
|
||||
ip = ip[12:16] // like in ip.To4()
|
||||
} else {
|
||||
family = familyIPv6
|
||||
}
|
||||
} else if len(ip) != net.IPv4len {
|
||||
return ErrBadIPLength
|
||||
}
|
||||
value := make([]byte, 128)
|
||||
value[0] = 0 // first 8 bits are zeroes
|
||||
bin.PutUint16(value[0:2], family)
|
||||
bin.PutUint16(value[2:4], uint16(a.Port))
|
||||
copy(value[4:], ip)
|
||||
m.Add(t, value[:4+len(ip)])
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTo adds MAPPED-ADDRESS to message.
|
||||
func (a *MappedAddress) AddTo(m *Message) error {
|
||||
return a.addAs(m, AttrMappedAddress)
|
||||
}
|
||||
|
||||
// GetFrom decodes MAPPED-ADDRESS from message.
|
||||
func (a *MappedAddress) GetFrom(m *Message) error {
|
||||
return a.getAs(m, AttrMappedAddress)
|
||||
}
|
92
addr_test.go
Normal file
92
addr_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMappedAddress(t *testing.T) {
|
||||
m := new(Message)
|
||||
addr := &MappedAddress{
|
||||
IP: net.ParseIP("122.12.34.5"),
|
||||
Port: 5412,
|
||||
}
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(MappedAddress)
|
||||
if err := got.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !got.IP.Equal(addr.IP) {
|
||||
t.Error("got bad IP: ", got.IP)
|
||||
}
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
message := new(Message)
|
||||
if err := got.GetFrom(message); err != ErrAttributeNotFound {
|
||||
t.Error("should be not found: ", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAlternateServer(t *testing.T) {
|
||||
m := new(Message)
|
||||
addr := &AlternateServer{
|
||||
IP: net.ParseIP("122.12.34.5"),
|
||||
Port: 5412,
|
||||
}
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(AlternateServer)
|
||||
if err := got.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !got.IP.Equal(addr.IP) {
|
||||
t.Error("got bad IP: ", got.IP)
|
||||
}
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
message := new(Message)
|
||||
if err := got.GetFrom(message); err != ErrAttributeNotFound {
|
||||
t.Error("should be not found: ", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkMappedAddress_AddTo(b *testing.B) {
|
||||
m := new(Message)
|
||||
b.ReportAllocs()
|
||||
addr := &MappedAddress{
|
||||
IP: net.ParseIP("122.12.34.5"),
|
||||
Port: 5412,
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
m.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAlternateServer_AddTo(b *testing.B) {
|
||||
m := new(Message)
|
||||
b.ReportAllocs()
|
||||
addr := &AlternateServer{
|
||||
IP: net.ParseIP("122.12.34.5"),
|
||||
Port: 5412,
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
m.Reset()
|
||||
}
|
||||
}
|
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
familyIPv4 byte = 0x01
|
||||
familyIPv6 byte = 0x02
|
||||
familyIPv4 uint16 = 0x01
|
||||
familyIPv6 uint16 = 0x02
|
||||
)
|
||||
|
||||
// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute.
|
||||
@@ -62,7 +62,7 @@ func (a *XORMappedAddress) AddTo(m *Message) error {
|
||||
xorValue := make([]byte, net.IPv6len)
|
||||
copy(xorValue[4:], m.TransactionID[:])
|
||||
bin.PutUint32(xorValue[0:4], magicCookie)
|
||||
bin.PutUint16(value[0:2], uint16(family))
|
||||
bin.PutUint16(value[0:2], family)
|
||||
bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16))
|
||||
xorBytes(value[4:4+len(ip)], ip, xorValue)
|
||||
m.Add(AttrXORMappedAddress, value[:4+len(ip)])
|
||||
@@ -95,7 +95,7 @@ func (a *XORMappedAddress) GetFrom(m *Message) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
family := byte(bin.Uint16(v[0:2]))
|
||||
family := bin.Uint16(v[0:2])
|
||||
if family != familyIPv6 && family != familyIPv4 {
|
||||
return newDecodeErr("xor-mapped address", "family",
|
||||
fmt.Sprintf("bad value %d", family),
|
||||
|
Reference in New Issue
Block a user