Makes UDPMux IPv4/IPv6 aware

UDPMux before only worked with UDP4 traffic.
UDP6 traffic would simply be ignored.

This commit implements 2 connections per ufrag. When requesting a
connection for a ufrag the user must specify if they want IPv4 or IPv6.

Relates to pion/webrtc#1915
This commit is contained in:
Antoine Baché
2022-03-02 15:09:12 +01:00
committed by Antoine Baché
parent 427ac0fddb
commit 45ff379fd3
6 changed files with 72 additions and 39 deletions

View File

@@ -236,7 +236,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
return errUDPMuxDisabled return errUDPMuxDisabled
} }
localIPs, err := localInterfaces(a.net, a.interfaceFilter, []NetworkType{NetworkTypeUDP4}) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.networkTypes)
switch { switch {
case err != nil: case err != nil:
return err return err
@@ -254,7 +254,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
} }
} }
conn, err := a.udpMux.GetConn(a.localUfrag) conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil)
if err != nil { if err != nil {
return err return err
} }
@@ -351,7 +351,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
for i := range urls { for i := range urls {
wg.Add(1) wg.Add(1)
go func(url URL, network string) { go func(url URL, network string, isIPv6 bool) {
defer wg.Done() defer wg.Done()
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
@@ -367,7 +367,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
return return
} }
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String()) conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), isIPv6)
if err != nil { if err != nil {
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err) a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err)
return return
@@ -397,7 +397,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
} }
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err) a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err)
} }
}(*urls[i], networkType.String()) }(*urls[i], networkType.String(), networkType.IsIPv6())
} }
} }
} }

View File

@@ -557,7 +557,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du
return nil, errNotImplemented return nil, errNotImplemented
} }
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.getConnForURLTimes++ m.getConnForURLTimes++

View File

@@ -14,7 +14,7 @@ import (
// UDPMux allows multiple connections to go over a single UDP port // UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface { type UDPMux interface {
io.Closer io.Closer
GetConn(ufrag string) (net.PacketConn, error) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string) RemoveConnByUfrag(ufrag string)
} }
@@ -25,8 +25,8 @@ type UDPMuxDefault struct {
closedChan chan struct{} closedChan chan struct{}
closeOnce sync.Once closeOnce sync.Once
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn connsIPv4, connsIPv6 map[string]*udpMuxedConn
addressMapMu sync.RWMutex addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn addressMap map[string]*udpMuxedConn
@@ -54,7 +54,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
m := &UDPMuxDefault{ m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{}, addressMap: map[string]*udpMuxedConn{},
params: params, params: params,
conns: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1), closedChan: make(chan struct{}, 1),
pool: &sync.Pool{ pool: &sync.Pool{
New: func() interface{} { New: func() interface{} {
@@ -76,7 +77,7 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network // GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) { func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -84,8 +85,8 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
return nil, io.ErrClosedPipe return nil, io.ErrClosedPipe
} }
if c, ok := m.conns[ufrag]; ok { if conn, ok := m.getConn(ufrag, isIPv6); ok {
return c, nil return conn, nil
} }
c := m.createMuxedConn(ufrag) c := m.createMuxedConn(ufrag)
@@ -93,26 +94,30 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
<-c.CloseChannel() <-c.CloseChannel()
m.removeConn(ufrag) m.removeConn(ufrag)
}() }()
m.conns[ufrag] = c
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}
return c, nil return c, nil
} }
// RemoveConnByUfrag stops and removes the muxed packet connection // RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Lock() removedConns := make([]*udpMuxedConn, 0, 2)
removedConns := make([]*udpMuxedConn, 0)
for key := range m.conns {
if key != ufrag {
continue
}
c := m.conns[key] // Keep lock section small to avoid deadlock with conn lock
delete(m.conns, key) m.mu.Lock()
if c != nil { if c, ok := m.connsIPv4[ufrag]; ok {
removedConns = append(removedConns, c) delete(m.connsIPv4, ufrag)
} removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
} }
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock() m.mu.Unlock()
m.addressMapMu.Lock() m.addressMapMu.Lock()
@@ -143,21 +148,39 @@ func (m *UDPMuxDefault) Close() error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
for _, c := range m.conns { for _, c := range m.connsIPv4 {
_ = c.Close() _ = c.Close()
} }
m.conns = make(map[string]*udpMuxedConn) for _, c := range m.connsIPv6 {
_ = c.Close()
}
m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)
close(m.closedChan) close(m.closedChan)
}) })
return err return err
} }
func (m *UDPMuxDefault) removeConn(key string) { func (m *UDPMuxDefault) removeConn(key string) {
m.mu.Lock()
c := m.conns[key]
delete(m.conns, key)
// keep lock section small to avoid deadlock with conn lock // keep lock section small to avoid deadlock with conn lock
m.mu.Unlock() c := func() *udpMuxedConn {
m.mu.Lock()
defer m.mu.Unlock()
if c, ok := m.connsIPv4[key]; ok {
delete(m.connsIPv4, key)
return c
}
if c, ok := m.connsIPv6[key]; ok {
delete(m.connsIPv6, key)
return c
}
return nil
}()
if c == nil { if c == nil {
return return
@@ -255,9 +278,10 @@ func (m *UDPMuxDefault) connWorker() {
} }
ufrag := strings.Split(string(attr), ":")[0] ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := udpAddr.IP.To4() == nil
m.mu.Lock() m.mu.Lock()
destinationConn = m.conns[ufrag] destinationConn, _ = m.getConn(ufrag, isIPv6)
m.mu.Unlock() m.mu.Unlock()
} }
@@ -272,6 +296,15 @@ func (m *UDPMuxDefault) connWorker() {
} }
} }
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}
type bufferHolder struct { type bufferHolder struct {
buffer []byte buffer []byte
} }

View File

@@ -65,7 +65,7 @@ func TestUDPMux(t *testing.T) {
require.NoError(t, udpMux.Close()) require.NoError(t, udpMux.Close())
// can't create more connections // can't create more connections
_, err = udpMux.GetConn("failufrag") _, err = udpMux.GetConn("failufrag", false)
require.Error(t, err) require.Error(t, err)
} }
@@ -110,7 +110,7 @@ func TestAddressEncoding(t *testing.T) {
} }
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) { func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag) pktConn, err := udpMux.GetConn(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() { defer func() {
_ = pktConn.Close() _ = pktConn.Close()

View File

@@ -16,7 +16,7 @@ type UniversalUDPMux interface {
UDPMux UDPMux
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
GetConnForURL(ufrag string, url string) (net.PacketConn, error) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error)
} }
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
@@ -84,8 +84,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server. // and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url)) return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6)
} }
// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address. // ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.

View File

@@ -40,7 +40,7 @@ func TestUniversalUDPMux(t *testing.T) {
} }
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) { func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag) pktConn, err := udpMux.GetConn(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() { defer func() {
_ = pktConn.Close() _ = pktConn.Close()