Makes UDPMux IPv4/IPv6 aware

UDPMux before only worked with UDP4 traffic.
UDP6 traffic would simply be ignored.

This commit implements 2 connections per ufrag. When requesting a
connection for a ufrag the user must specify if they want IPv4 or IPv6.

Relates to pion/webrtc#1915
This commit is contained in:
Antoine Baché
2022-03-02 15:09:12 +01:00
committed by Antoine Baché
parent 427ac0fddb
commit 45ff379fd3
6 changed files with 72 additions and 39 deletions

View File

@@ -236,7 +236,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
return errUDPMuxDisabled
}
localIPs, err := localInterfaces(a.net, a.interfaceFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.networkTypes)
switch {
case err != nil:
return err
@@ -254,7 +254,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
}
}
conn, err := a.udpMux.GetConn(a.localUfrag)
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil)
if err != nil {
return err
}
@@ -351,7 +351,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
for i := range urls {
wg.Add(1)
go func(url URL, network string) {
go func(url URL, network string, isIPv6 bool) {
defer wg.Done()
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
@@ -367,7 +367,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
return
}
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String())
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), isIPv6)
if err != nil {
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err)
return
@@ -397,7 +397,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
}
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err)
}
}(*urls[i], networkType.String())
}(*urls[i], networkType.String(), networkType.IsIPv6())
}
}
}

View File

@@ -557,7 +557,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du
return nil, errNotImplemented
}
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getConnForURLTimes++

View File

@@ -14,7 +14,7 @@ import (
// UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface {
io.Closer
GetConn(ufrag string) (net.PacketConn, error)
GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}
@@ -25,8 +25,8 @@ type UDPMuxDefault struct {
closedChan chan struct{}
closeOnce sync.Once
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
@@ -54,7 +54,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
params: params,
conns: make(map[string]*udpMuxedConn),
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
@@ -76,7 +77,7 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -84,8 +85,8 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
return nil, io.ErrClosedPipe
}
if c, ok := m.conns[ufrag]; ok {
return c, nil
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
}
c := m.createMuxedConn(ufrag)
@@ -93,26 +94,30 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
<-c.CloseChannel()
m.removeConn(ufrag)
}()
m.conns[ufrag] = c
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Lock()
removedConns := make([]*udpMuxedConn, 0)
for key := range m.conns {
if key != ufrag {
continue
}
removedConns := make([]*udpMuxedConn, 0, 2)
c := m.conns[key]
delete(m.conns, key)
if c != nil {
removedConns = append(removedConns, c)
}
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
}
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
m.addressMapMu.Lock()
@@ -143,21 +148,39 @@ func (m *UDPMuxDefault) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
for _, c := range m.conns {
for _, c := range m.connsIPv4 {
_ = c.Close()
}
m.conns = make(map[string]*udpMuxedConn)
for _, c := range m.connsIPv6 {
_ = c.Close()
}
m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)
close(m.closedChan)
})
return err
}
func (m *UDPMuxDefault) removeConn(key string) {
m.mu.Lock()
c := m.conns[key]
delete(m.conns, key)
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
c := func() *udpMuxedConn {
m.mu.Lock()
defer m.mu.Unlock()
if c, ok := m.connsIPv4[key]; ok {
delete(m.connsIPv4, key)
return c
}
if c, ok := m.connsIPv6[key]; ok {
delete(m.connsIPv6, key)
return c
}
return nil
}()
if c == nil {
return
@@ -255,9 +278,10 @@ func (m *UDPMuxDefault) connWorker() {
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := udpAddr.IP.To4() == nil
m.mu.Lock()
destinationConn = m.conns[ufrag]
destinationConn, _ = m.getConn(ufrag, isIPv6)
m.mu.Unlock()
}
@@ -272,6 +296,15 @@ func (m *UDPMuxDefault) connWorker() {
}
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}
type bufferHolder struct {
buffer []byte
}

View File

@@ -65,7 +65,7 @@ func TestUDPMux(t *testing.T) {
require.NoError(t, udpMux.Close())
// can't create more connections
_, err = udpMux.GetConn("failufrag")
_, err = udpMux.GetConn("failufrag", false)
require.Error(t, err)
}
@@ -110,7 +110,7 @@ func TestAddressEncoding(t *testing.T) {
}
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag)
pktConn, err := udpMux.GetConn(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()

View File

@@ -16,7 +16,7 @@ type UniversalUDPMux interface {
UDPMux
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
GetConnForURL(ufrag string, url string) (net.PacketConn, error)
GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error)
}
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
@@ -84,8 +84,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url))
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6)
}
// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.

View File

@@ -40,7 +40,7 @@ func TestUniversalUDPMux(t *testing.T) {
}
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag)
pktConn, err := udpMux.GetConn(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()