Use MultiUDPMux to implement listen any address

In #475, import low-level API (ReadMsgUDP) to determine
destination interface for packets received by UDPConn listen
at any(unspecified) address, to fix msg received by incorrect
candidate that shared same ufrags. But the api has compatibility
issues, also not reliable in some special network cases like
AWS/ECS.
So this pr revert that change, and make UDPMuxDefault not
accept Conn listen at unspecified address. Also provide a
NewMultiUDPMuxFromPort helper function to create a MultiUDPMux
to listen at all addresses.
For ice gather, it will use UDPMux's listen address to generate
canidates instead of create it from interfaces.
This commit is contained in:
cnderrauber
2022-10-09 00:39:39 +08:00
committed by cnderrauber
parent af9281dc76
commit 04a6027e93
14 changed files with 499 additions and 518 deletions

View File

@@ -24,11 +24,14 @@ func TestMuxAgent(t *testing.T) {
const muxPort = 7686
c, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: muxPort,
})
require.NoError(t, err)
loggerFactory := logging.NewDefaultLoggerFactory()
udpMux := NewUDPMuxDefault(UDPMuxParams{
udpMux, err := NewUDPMuxDefault(UDPMuxParams{
Logger: loggerFactory.NewLogger("ice"),
UDPConn: c,
})

View File

@@ -134,11 +134,12 @@ var (
errMismatchUsername = errors.New("username mismatch")
errICEWriteSTUNMessage = errors.New("the ICE conn can't write STUN messages")
errUDPMuxDisabled = errors.New("UDPMux is not enabled")
errCandidateIPNotFound = errors.New("could not determine local IP for Mux candidate")
errNoXorAddrMapping = errors.New("no address mapping")
errSendSTUNPacket = errors.New("failed to send STUN packet")
errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr")
errNotImplemented = errors.New("not implemented yet")
errNoUDPMuxAvailable = errors.New("no UDP mux is available")
errNoTCPMuxAvailable = errors.New("no TCP mux is available")
errListenUnspecified = errors.New("can't listen on unspecified address")
errInvalidAddress = errors.New("invalid address")
)

View File

@@ -273,16 +273,14 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
return errUDPMuxDisabled
}
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes)
switch {
case err != nil:
return err
case len(localIPs) == 0:
return errCandidateIPNotFound
}
localAddresses := a.udpMux.GetListenAddresses()
for _, candidateIP := range localIPs {
localIP := candidateIP
for _, addr := range localAddresses {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return errInvalidAddress
}
candidateIP := udpAddr.IP
if a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
if mappedIP, innerErr := a.extIPMapper.findExternalIP(candidateIP.String()); innerErr != nil {
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String())
@@ -292,31 +290,10 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
}
}
var conns []net.PacketConn
if multi, ok := a.udpMux.(AllConnsGetter); ok {
conns, err = multi.GetAllConns(a.localUfrag, candidateIP.To4() == nil, localIP)
conn, err := a.udpMux.GetConn(a.localUfrag, udpAddr)
if err != nil {
return err
}
if len(conns) == 0 {
a.log.Warnf("Failed to get any connections from MultiUDPMux for candidate: %s", candidateIP)
continue
}
} else {
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil, localIP)
if err != nil {
return err
}
conns = []net.PacketConn{conn}
}
for _, conn := range conns {
udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create host mux candidate: %s failed to cast", candidateIP))
continue
}
hostConfig := CandidateHostConfig{
Network: udp,
Address: candidateIP.String(),
@@ -339,7 +316,6 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
continue
}
}
}
return nil
}
@@ -414,8 +390,14 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
}
for i := range urls {
for _, listenAddr := range a.udpMuxSrflx.GetListenAddresses() {
udpAddr, ok := listenAddr.(*net.UDPAddr)
if !ok {
a.log.Warn("Failed to cast udpMuxSrflx listen address to UDPAddr")
continue
}
wg.Add(1)
go func(url URL, network string, isIPv6 bool) {
go func(url URL, network string, localAddr *net.UDPAddr) {
defer wg.Done()
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
@@ -431,7 +413,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
return
}
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), isIPv6)
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), localAddr)
if err != nil {
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v", network, url, err)
return
@@ -440,19 +422,13 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
ip := xoraddr.IP
port := xoraddr.Port
laddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create server reflexive candidate: %s %s %d: cast failed", network, ip, port))
return
}
srflxConfig := CandidateServerReflexiveConfig{
Network: network,
Address: ip.String(),
Port: port,
Component: ComponentRTP,
RelAddr: laddr.IP.String(),
RelPort: laddr.Port,
RelAddr: localAddr.IP.String(),
RelPort: localAddr.Port,
}
c, err := NewCandidateServerReflexive(&srflxConfig)
if err != nil {
@@ -466,7 +442,8 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
}
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
}
}(*urls[i], networkType.String(), networkType.IsIPv6())
}(*urls[i], networkType.String(), udpAddr)
}
}
}
}

View File

@@ -507,9 +507,9 @@ func TestMultiUDPMuxUsage(t *testing.T) {
}()
expectedPorts = append(expectedPorts, port)
udpMuxInstances = append(udpMuxInstances, NewUDPMuxDefault(UDPMuxParams{
UDPConn: conn,
}))
muxDefault, err := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn})
assert.NoError(t, err)
udpMuxInstances = append(udpMuxInstances, muxDefault)
idx := i
defer func() {
_ = udpMuxInstances[idx].Close()
@@ -675,7 +675,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du
return nil, errNotImplemented
}
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getConnForURLTimes++
@@ -694,3 +694,7 @@ func (m *universalUDPMuxMock) RemoveConnByUfrag(ufrag string) {
defer m.mu.Unlock()
m.removeConnByUfragTimes++
}
func (m *universalUDPMuxMock) GetListenAddresses() []net.Addr {
return []net.Addr{m.conn.LocalAddr()}
}

View File

@@ -5,6 +5,14 @@ package ice
import "net"
// AllConnsGetter allows multiple fixed TCP ports to be used,
// each which is multiplexed like TCPMux. AllConnsGetter also acts as
// a TCPMux, in which case it will return a single connection for one
// of the ports.
type AllConnsGetter interface {
GetAllConns(ufrag string, isIPv6 bool, localIP net.IP) ([]net.PacketConn, error)
}
// MultiTCPMuxDefault implements both TCPMux and AllConnsGetter,
// allowing users to pass multiple TCPMux instances to the ICE agent
// configuration.

View File

@@ -15,8 +15,9 @@ import (
// UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface {
io.Closer
GetConn(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error)
GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
GetListenAddresses() []net.Addr
}
// UDPMuxDefault is an implementation of the interface
@@ -26,13 +27,11 @@ type UDPMuxDefault struct {
closedChan chan struct{}
closeOnce sync.Once
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]map[ipAddr]*udpMuxedConn
// conns are maps of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn
addressMapMu sync.RWMutex
// remote address (ip:port) -> (localip -> udpMuxedConn)
addressMap map[string]map[ipAddr]*udpMuxedConn
addressMap map[string]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
@@ -42,37 +41,28 @@ type UDPMuxDefault struct {
const maxAddrSize = 512
// UDPMuxConn is a udp PacketConn with ReadMsgUDP and File method
// to retrieve the destination local address of the received packet
type UDPMuxConn interface {
net.PacketConn
// ReadMsgUdp used to get destination address when received a udp packet
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
// File returns a copy of the underlying os.File.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
File() (f *os.File, err error)
}
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn UDPMuxConn
UDPConn net.PacketConn
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
func NewUDPMuxDefault(params UDPMuxParams) (*UDPMuxDefault, error) {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
return nil, errInvalidAddress
} else if ok && addr.IP.IsUnspecified() {
return nil, errListenUnspecified
}
m := &UDPMuxDefault{
addressMap: make(map[string]map[ipAddr]*udpMuxedConn),
addressMap: map[string]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]map[ipAddr]*udpMuxedConn),
connsIPv6: make(map[string]map[ipAddr]*udpMuxedConn),
conns: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
@@ -84,7 +74,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
go m.connWorker()
return m
return m, nil
}
// LocalAddr returns the listening address of this UDPMuxDefault
@@ -92,9 +82,17 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}
}
// GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) {
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
if m.params.UDPConn.LocalAddr() != addr {
return nil, errInvalidAddress
}
m.mu.Lock()
defer m.mu.Unlock()
@@ -102,60 +100,34 @@ func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool, local net.IP) (net.Pa
return nil, io.ErrClosedPipe
}
if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
if conn, ok := m.getConn(ufrag); ok {
return conn, nil
}
c, err := m.createMuxedConn(ufrag, local)
if err != nil {
return nil, err
}
c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.removeConnByUfragAndLocalHost(ufrag, local)
m.RemoveConnByUfrag(ufrag)
}()
var (
conns map[ipAddr]*udpMuxedConn
ok bool
)
if isIPv6 {
if conns, ok = m.connsIPv6[ufrag]; !ok {
conns = make(map[ipAddr]*udpMuxedConn)
m.connsIPv6[ufrag] = conns
}
} else {
if conns, ok = m.connsIPv4[ufrag]; !ok {
conns = make(map[ipAddr]*udpMuxedConn)
m.connsIPv4[ufrag] = conns
}
}
conns[ipAddr(local.String())] = c
m.conns[ufrag] = c
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 4)
var removedConn *udpMuxedConn
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if conns, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
for _, c := range conns {
removedConns = append(removedConns, c)
}
}
if conns, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
for _, c := range conns {
removedConns = append(removedConns, c)
}
if c, ok := m.conns[ufrag]; ok {
delete(m.conns, ufrag)
removedConn = c
}
m.mu.Unlock()
if len(removedConns) == 0 {
if removedConn == nil {
// No need to lock if no connection was found
return
}
@@ -163,65 +135,11 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
addresses := removedConn.getAddresses()
for _, addr := range addresses {
if conns, ok := m.addressMap[addr]; ok {
delete(conns, ipAddr(c.params.LocalIP.String()))
if len(conns) == 0 {
delete(m.addressMap, addr)
}
}
}
}
}
func (m *UDPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP) {
removedConns := make([]*udpMuxedConn, 0, 4)
localIP := ipAddr(local.String())
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if conns, ok := m.connsIPv4[ufrag]; ok {
if conn, ok := conns[localIP]; ok {
delete(conns, localIP)
if len(conns) == 0 {
delete(m.connsIPv4, ufrag)
}
removedConns = append(removedConns, conn)
}
}
if conns, ok := m.connsIPv6[ufrag]; ok {
if conn, ok := conns[localIP]; ok {
delete(conns, localIP)
if len(conns) == 0 {
delete(m.connsIPv6, ufrag)
}
removedConns = append(removedConns, conn)
}
}
m.mu.Unlock()
if len(removedConns) == 0 {
// No need to lock if no connection was found
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
if conns, ok := m.addressMap[addr]; ok {
delete(conns, ipAddr(c.params.LocalIP.String()))
if len(conns) == 0 {
delete(m.addressMap, addr)
}
}
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
@@ -240,40 +158,15 @@ func (m *UDPMuxDefault) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
for _, conns := range m.connsIPv4 {
for _, c := range conns {
for _, c := range m.conns {
_ = c.Close()
}
}
for _, conns := range m.connsIPv6 {
for _, c := range conns {
_ = c.Close()
}
}
m.connsIPv4 = make(map[string]map[ipAddr]*udpMuxedConn)
m.connsIPv6 = make(map[string]map[ipAddr]*udpMuxedConn)
m.conns = make(map[string]*udpMuxedConn)
// ReadMsgUDP will block until something is received, otherwise it will block forever
// and the Conn's Close method too. So send a packet to wake it for exit.
close(m.closedChan)
closeConn, errConn := net.DialUDP("udp", nil, m.params.UDPConn.LocalAddr().(*net.UDPAddr))
// i386 doesn't support dial local ipv6 address
if errConn != nil && strings.Contains(errConn.Error(), "dial udp [::]:") &&
strings.Contains(errConn.Error(), "connect: cannot assign requested address") {
closeConn, errConn = net.DialUDP("udp4", nil, &net.UDPAddr{Port: m.params.UDPConn.LocalAddr().(*net.UDPAddr).Port})
}
if errConn != nil {
m.params.Logger.Errorf("Failed to open close notify socket, %v", errConn)
} else {
defer func() {
_ = closeConn.Close()
}()
_, errConn = closeConn.Write([]byte("close"))
if errConn != nil {
m.params.Logger.Errorf("Failed to send close notify msg, %v", errConn)
}
}
_ = m.params.UDPConn.Close()
})
return err
}
@@ -290,58 +183,36 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
conns, ok := m.addressMap[addr]
if ok {
existing, ok := conns[ipAddr(conn.params.LocalIP.String())]
existing, ok := m.addressMap[addr]
if ok {
existing.removeAddress(addr)
}
} else {
conns = make(map[ipAddr]*udpMuxedConn)
m.addressMap[addr] = conns
}
conns[ipAddr(conn.params.LocalIP.String())] = conn
m.addressMap[addr] = conn
m.params.Logger.Debugf("Registered %s for %s, local %s", addr, conn.params.Key, conn.params.LocalIP.String())
m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string, local net.IP) (*udpMuxedConn, error) {
m.params.Logger.Debugf("Creating new muxed connection, key:%s local:%s ", key, local.String())
addr, ok := m.LocalAddr().(*net.UDPAddr)
if !ok {
return nil, ErrGetTransportAddress
}
localAddr := *addr
localAddr.IP = local
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: &localAddr,
LocalIP: local,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c, nil
return c
}
func (m *UDPMuxDefault) connWorker() { //nolint:gocognit
func (m *UDPMuxDefault) connWorker() {
logger := m.params.Logger
defer func() {
_ = m.Close()
}()
localUDPAddr, _ := m.LocalAddr().(*net.UDPAddr)
buf := make([]byte, receiveMTU)
file, _ := m.params.UDPConn.File()
setUDPSocketOptionsForLocalAddr(file.Fd(), m.params.Logger)
_ = file.Close()
oob := make([]byte, receiveMTU)
for {
localHost := localUDPAddr.IP
n, oobn, _, addr, err := m.params.UDPConn.ReadMsgUDP(buf, oob)
n, addr, err := m.params.UDPConn.ReadFrom(buf)
if m.IsClosed() {
return
} else if err != nil {
@@ -354,29 +225,19 @@ func (m *UDPMuxDefault) connWorker() { //nolint:gocognit
return
}
// get destination local addr from received packet
if oobIP, addrErr := getLocalAddrFromOob(oob[:oobn]); addrErr == nil {
localHost = oobIP
} else {
m.params.Logger.Warnf("could not get local addr from oob: %v, remote %s", addrErr, addr)
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
logger.Errorf("underlying PacketConn did not return a UDPAddr")
return
}
// If we have already seen this address dispatch to the appropriate destination
var destinationConn *udpMuxedConn
m.addressMapMu.Lock()
if conns, ok := m.addressMap[addr.String()]; ok {
destinationConn, ok = conns[ipAddr(localHost.String())]
if !ok {
for _, c := range conns {
destinationConn = c
break
}
}
}
destinationConn := m.addressMap[addr.String()]
m.addressMapMu.Unlock()
// If we haven't seen this address before but is a STUN packet lookup by ufrag
if destinationConn == nil && stun.IsMessage(buf[:n]) && !localHost.IsUnspecified() {
if destinationConn == nil && stun.IsMessage(buf[:n]) {
msg := &stun.Message{
Raw: append([]byte{}, buf[:n]...),
}
@@ -393,34 +254,25 @@ func (m *UDPMuxDefault) connWorker() { //nolint:gocognit
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := addr.IP.To4() == nil
m.mu.Lock()
destinationConn, _ = m.getConn(ufrag, isIPv6, localHost)
destinationConn, _ = m.getConn(ufrag)
m.mu.Unlock()
}
if destinationConn == nil {
m.params.Logger.Tracef("dropping packet from %s", addr.String())
m.params.Logger.Tracef("dropping packet from %s, addr: %s", udpAddr.String(), addr.String())
continue
}
if err = destinationConn.writePacket(buf[:n], addr); err != nil {
if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil {
m.params.Logger.Errorf("could not write packet: %v", err)
}
}
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *udpMuxedConn, ok bool) {
var conns map[ipAddr]*udpMuxedConn
if isIPv6 {
conns, ok = m.connsIPv6[ufrag]
} else {
conns, ok = m.connsIPv4[ufrag]
}
if conns != nil {
val, ok = conns[ipAddr(local.String())]
}
func (m *UDPMuxDefault) getConn(ufrag string) (val *udpMuxedConn, ok bool) {
val, ok = m.conns[ufrag]
return
}

View File

@@ -3,42 +3,44 @@
//nolint:dupl
package ice
import "net"
import (
"net"
// AllConnsGetter allows multiple fixed UDP or TCP ports to be used,
// each which is multiplexed like UDPMux. AllConnsGetter also acts as
// a UDPMux, in which case it will return a single connection for one
// of the ports.
type AllConnsGetter interface {
GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error)
}
"github.com/pion/logging"
"github.com/pion/transport/vnet"
)
// MultiUDPMuxDefault implements both UDPMux and AllConnsGetter,
// allowing users to pass multiple UDPMux instances to the ICE agent
// configuration.
type MultiUDPMuxDefault struct {
muxs []UDPMux
localAddrToMux map[string]UDPMux
}
// NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that
// uses the provided UDPMux instances.
func NewMultiUDPMuxDefault(muxs ...UDPMux) *MultiUDPMuxDefault {
addrToMux := make(map[string]UDPMux)
for _, mux := range muxs {
for _, addr := range mux.GetListenAddresses() {
addrToMux[addr.String()] = mux
}
}
return &MultiUDPMuxDefault{
muxs: muxs,
localAddrToMux: addrToMux,
}
}
// GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found. This, unlike
// GetAllConns, will only return a single PacketConn from the first
// mux that was passed in to NewMultiUDPMuxDefault.
func (m *MultiUDPMuxDefault) GetConn(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) {
// NOTE: We always use the first element here in order to maintain the
// behavior of using an existing connection if one exists.
if len(m.muxs) == 0 {
// creates the connection if an existing one can't be found.
func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
mux, ok := m.localAddrToMux[addr.String()]
if !ok {
return nil, errNoUDPMuxAvailable
}
return m.muxs[0].GetConn(ufrag, isIPv6, local)
return mux.GetConn(ufrag, addr)
}
// RemoveConnByUfrag stops and removes the muxed packet connection
@@ -49,26 +51,6 @@ func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) {
}
}
// GetAllConns returns a PacketConn for each underlying UDPMux
func (m *MultiUDPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) {
if len(m.muxs) == 0 {
// Make sure that we either return at least one connection or an error.
return nil, errNoUDPMuxAvailable
}
var conns []net.PacketConn
for _, mux := range m.muxs {
conn, err := mux.GetConn(ufrag, isIPv6, local)
if err != nil {
// For now, this implementation is all or none.
return nil, err
}
if conn != nil {
conns = append(conns, conn)
}
}
return conns, nil
}
// Close the multi mux, no further connections could be created
func (m *MultiUDPMuxDefault) Close() error {
var err error
@@ -79,3 +61,146 @@ func (m *MultiUDPMuxDefault) Close() error {
}
return err
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr {
addrs := make([]net.Addr, 0, len(m.localAddrToMux))
for _, mux := range m.muxs {
addrs = append(addrs, mux.GetListenAddresses()...)
}
return addrs
}
// NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that
// listen all interfaces on the provided port.
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) {
params := multiUDPMuxFromPortParam{
networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
}
for _, opt := range opts {
opt.apply(&params)
}
muxNet := vnet.NewNet(nil)
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks)
if err != nil {
return nil, err
}
conns := make([]net.PacketConn, 0, len(ips))
for _, ip := range ips {
conn, listenErr := net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port})
if listenErr != nil {
err = listenErr
break
}
if params.readBufferSize > 0 {
_ = conn.SetReadBuffer(params.readBufferSize)
}
if params.writeBufferSize > 0 {
_ = conn.SetWriteBuffer(params.writeBufferSize)
}
conns = append(conns, conn)
}
if err != nil {
for _, conn := range conns {
_ = conn.Close()
}
return nil, err
}
muxs := make([]UDPMux, 0, len(conns))
for _, conn := range conns {
mux, muxErr := NewUDPMuxDefault(UDPMuxParams{Logger: params.logger, UDPConn: conn})
if muxErr != nil {
err = muxErr
break
}
muxs = append(muxs, mux)
}
if err != nil {
for _, mux := range muxs {
_ = mux.Close()
}
return nil, err
}
return NewMultiUDPMuxDefault(muxs...), nil
}
// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort
type UDPMuxFromPortOption interface {
apply(*multiUDPMuxFromPortParam)
}
type multiUDPMuxFromPortParam struct {
ifFilter func(string) bool
ipFilter func(ip net.IP) bool
networks []NetworkType
readBufferSize int
writeBufferSize int
logger logging.LeveledLogger
}
type udpMuxFromPortOption struct {
f func(*multiUDPMuxFromPortParam)
}
func (o *udpMuxFromPortOption) apply(p *multiUDPMuxFromPortParam) {
o.f(p)
}
// UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used
func UDPMuxFromPortWithInterfaceFilter(f func(string) bool) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.ifFilter = f
},
}
}
// UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used
func UDPMuxFromPortWithIPFilter(f func(ip net.IP) bool) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.ipFilter = f
},
}
}
// UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6
func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.networks = networks
},
}
}
// UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size
func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.readBufferSize = size
},
}
}
// UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size
func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.writeBufferSize = size
},
}
}
// UDPMuxFromPortWithLogger set the logger for the created UDPMux
func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.logger = logger
},
}
}

View File

@@ -5,6 +5,7 @@ package ice
import (
"net"
"strings"
"sync"
"testing"
"time"
@@ -20,16 +21,32 @@ func TestMultiUDPMux(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
conn1, err := net.ListenUDP(udp, &net.UDPAddr{})
conn1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
conn2, err := net.ListenUDP(udp, &net.UDPAddr{})
conn2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udpMuxMulti := NewMultiUDPMuxDefault(
NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1}),
NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2}),
)
conn3, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6loopback})
if err != nil {
// ipv6 is not supported on this machine
t.Log("ipv6 is not supported on this machine")
}
muxes := []UDPMux{}
muxV41, err := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1})
require.NoError(t, err)
muxes = append(muxes, muxV41)
muxV42, err := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2})
require.NoError(t, err)
muxes = append(muxes, muxV42)
if conn3 != nil {
muxV6, v6err := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn3})
require.NoError(t, v6err)
muxes = append(muxes, muxV6)
}
udpMuxMulti := NewMultiUDPMuxDefault(muxes...)
defer func() {
_ = udpMuxMulti.Close()
_ = conn1.Close()
@@ -60,32 +77,77 @@ func TestMultiUDPMux(t *testing.T) {
require.NoError(t, udpMuxMulti.Close())
// can't create more connections
_, err = udpMuxMulti.GetConn("failufrag", false, net.IP{})
_, err = udpMuxMulti.GetConn("failufrag", conn1.LocalAddr())
require.Error(t, err)
}
func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) {
pktConns, err := udpMuxMulti.GetAllConns(ufrag, false, net.IP{127, 0, 0, 1})
addrs := udpMuxMulti.GetListenAddresses()
pktConns := make([]net.PacketConn, 0, len(addrs))
for _, addr := range addrs {
udpAddr, ok := addr.(*net.UDPAddr)
require.True(t, ok)
if network == "udp4" && udpAddr.IP.To4() == nil {
continue
} else if network == "udp6" && udpAddr.IP.To4() != nil {
continue
}
c, err := udpMuxMulti.GetConn(ufrag, addr)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
pktConns = append(pktConns, c)
}
defer func() {
for _, c := range pktConns {
_ = c.Close()
}
}()
require.Len(t, pktConns, len(udpMuxMulti.muxs), "there should be a PacketConn for every mux")
// Try talking with each PacketConn
for i, pktConn := range pktConns {
remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{
Port: pktConn.LocalAddr().(*net.UDPAddr).Port,
})
for _, pktConn := range pktConns {
remoteConn, err := net.DialUDP(network, nil, pktConn.LocalAddr().(*net.UDPAddr))
require.NoError(t, err, "error dialing test udp connection")
localConn, err := udpMuxMulti.muxs[i].GetConn(ufrag, false, remoteConn.RemoteAddr().(*net.UDPAddr).IP)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
}
}
func TestUnspecifiedUDPMux(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
muxPort := 7778
udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(func(s string) bool {
return !strings.Contains(s, "docker")
}))
require.NoError(t, err)
require.GreaterOrEqual(t, len(udpMuxMulti.muxs), 1, "at least have 1 muxs")
defer func() {
_ = pktConn.Close()
_ = udpMuxMulti.Close()
}()
testMuxConnectionPair(t, localConn, remoteConn, ufrag)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp)
}()
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", "udp4")
}()
// skip ipv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 {
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", "udp6")
}
wg.Wait()
require.NoError(t, udpMuxMulti.Close())
}

View File

@@ -25,10 +25,21 @@ func TestUDPMux(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
conn4, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udpMux := NewUDPMuxDefault(UDPMuxParams{
conn6, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6loopback})
if err != nil {
t.Log("IPv6 is not supported on this machine")
}
for network, c := range map[string]net.PacketConn{"udp4": conn4, "udp6": conn6} {
if c == nil {
continue
}
conn := c
t.Run(network, func(t *testing.T) {
udpMux, err := NewUDPMuxDefault(UDPMuxParams{
Logger: nil,
UDPConn: conn,
})
@@ -49,16 +60,11 @@ func TestUDPMux(t *testing.T) {
defer wg.Done()
testMuxConnection(t, udpMux, "ufrag1", udp)
}()
wg.Add(1)
go func() {
defer wg.Done()
testMuxConnection(t, udpMux, "ufrag2", "udp4")
}()
// skip ipv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 {
testMuxConnection(t, udpMux, "ufrag3", "udp6")
if ptrSize != 32 || network != "udp6" {
testMuxConnection(t, udpMux, "ufrag2", network)
}
wg.Wait()
@@ -66,8 +72,24 @@ func TestUDPMux(t *testing.T) {
require.NoError(t, udpMux.Close())
// can't create more connections
_, err = udpMux.GetConn("failufrag", false, net.IPv4zero)
_, err = udpMux.GetConn("failufrag", udpMux.LocalAddr())
require.Error(t, err)
})
}
}
func TestCantMuxUnspecifiedAddr(t *testing.T) {
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
require.NoError(t, err)
_, err = NewUDPMuxDefault(UDPMuxParams{
Logger: nil,
UDPConn: conn,
})
require.Equal(t, errListenUnspecified, err)
_ = conn.Close()
}
func TestAddressEncoding(t *testing.T) {
@@ -111,17 +133,15 @@ func TestAddressEncoding(t *testing.T) {
}
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{
Port: udpMux.LocalAddr().(*net.UDPAddr).Port,
})
require.NoError(t, err, "error dialing test udp connection")
pktConn, err := udpMux.GetConn(ufrag, false, remoteConn.RemoteAddr().(*net.UDPAddr).IP)
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
}()
remoteConn, err := net.DialUDP(network, nil, pktConn.LocalAddr().(*net.UDPAddr))
require.NoError(t, err, "error dialing test udp connection")
testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
}

View File

@@ -16,7 +16,7 @@ type UniversalUDPMux interface {
UDPMux
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error)
GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error)
}
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
@@ -33,12 +33,12 @@ type UniversalUDPMuxDefault struct {
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn UDPMuxConn
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) (*UniversalUDPMuxDefault, error) {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
@@ -54,7 +54,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
UDPMuxConn: params.UDPConn,
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
}
@@ -64,14 +64,18 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
Logger: params.Logger,
UDPConn: m.params.UDPConn,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
muxDefault, err := NewUDPMuxDefault(udpMuxParams)
if err != nil {
return nil, err
}
m.UDPMuxDefault = muxDefault
return m
return m, nil
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
UDPMuxConn
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
}
@@ -84,48 +88,44 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6, net.IPv4zero)
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
}
// ReadMsgUDP is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.
// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.
// It passes processed packets further to the UDPMux (maybe this is not really necessary).
func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) {
n, oobn, flags, addr, err = c.UDPMuxConn.ReadMsgUDP(b, oob)
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p)
if err != nil {
return
}
if stun.IsMessage(b[:n]) {
bytes := make([]byte, n)
copy(bytes, b[:n])
if stun.IsMessage(p[:n]) {
msg := &stun.Message{
Raw: bytes,
Raw: append([]byte{}, p[:n]...),
}
if err = msg.Decode(); err != nil {
c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
err = nil
return n, addr, nil
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return
}
if c.mux.isXORMappedResponse(msg, addr.String()) {
err = c.mux.handleXORMappedResponse(addr, msg)
if c.mux.isXORMappedResponse(msg, udpAddr.String()) {
err = c.mux.handleXORMappedResponse(udpAddr, msg)
if err != nil {
c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err)
err = nil
return
return n, addr, nil
}
return
}
}
return
}
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
oob := make([]byte, 100)
n, _, _, addr, err = c.ReadMsgUDP(p, oob)
return
return n, addr, err
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -14,10 +14,10 @@ import (
)
func TestUniversalUDPMux(t *testing.T) {
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{
udpMux, err := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{
Logger: nil,
UDPConn: conn,
})
@@ -41,7 +41,7 @@ func TestUniversalUDPMux(t *testing.T) {
}
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag, false, net.IPv4zero)
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()

View File

@@ -16,7 +16,6 @@ type udpMuxedConnParams struct {
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
LocalIP net.IP
Logger logging.LeveledLogger
}

View File

@@ -1,51 +0,0 @@
//go:build !js && !windows
package ice
import (
"bytes"
"encoding/binary"
"errors"
"net"
"syscall"
"github.com/pion/logging"
)
var errUnknownOobData = errors.New("unknown oob data")
func setUDPSocketOptionsForLocalAddr(fd uintptr, logger logging.LeveledLogger) {
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_2292PKTINFO, 1); err != nil {
logger.Warnf("Failed to set sockopt IPV6_2292PKTINFO: %s", err)
}
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1); err != nil {
logger.Warnf("Failed to set sockopt IP_PKTINFO: %s", err)
}
}
func getLocalAddrFromOob(oob []byte) (net.IP, error) {
var localHost net.IP
// get destination local addr from received packet
oobBuffer := bytes.NewBuffer(oob)
msg := syscall.Cmsghdr{}
err := binary.Read(oobBuffer, binary.LittleEndian, &msg)
if err == nil {
switch {
case msg.Level == syscall.IPPROTO_IP && msg.Type == syscall.IP_PKTINFO:
packetInfo := syscall.Inet4Pktinfo{}
if err = binary.Read(oobBuffer, binary.LittleEndian, &packetInfo); err == nil {
localHost = net.IP(packetInfo.Addr[:])
return localHost, nil
}
case msg.Level == syscall.IPPROTO_IPV6 && msg.Type == syscall.IPV6_2292PKTINFO:
packetInfo := syscall.Inet6Pktinfo{}
if err = binary.Read(oobBuffer, binary.LittleEndian, &packetInfo); err == nil {
localHost = net.IP(packetInfo.Addr[:])
return localHost, nil
}
default:
return localHost, errUnknownOobData
}
}
return localHost, err
}

View File

@@ -1,19 +0,0 @@
//go:build js || windows
package ice
import (
"errors"
"net"
"github.com/pion/logging"
)
var errUnsupported = errors.New("unsupported")
func setUDPSocketOptionsForLocalAddr(fd uintptr, logger logging.LeveledLogger) {
}
func getLocalAddrFromOob(oob []byte) (net.IP, error) {
return nil, errUnsupported
}