mirror of
https://github.com/pion/ice.git
synced 2025-09-27 03:45:54 +08:00
Makes UDPMux IPv4/IPv6 aware
UDPMux before only worked with UDP4 traffic. UDP6 traffic would simply be ignored. This commit implements 2 connections per ufrag. When requesting a connection for a ufrag the user must specify if they want IPv4 or IPv6. Relates to pion/webrtc#1915
This commit is contained in:

committed by
Antoine Baché

parent
427ac0fddb
commit
45ff379fd3
10
gather.go
10
gather.go
@@ -236,7 +236,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
|
||||
return errUDPMuxDisabled
|
||||
}
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, []NetworkType{NetworkTypeUDP4})
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.networkTypes)
|
||||
switch {
|
||||
case err != nil:
|
||||
return err
|
||||
@@ -254,7 +254,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := a.udpMux.GetConn(a.localUfrag)
|
||||
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -351,7 +351,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
|
||||
|
||||
for i := range urls {
|
||||
wg.Add(1)
|
||||
go func(url URL, network string) {
|
||||
go func(url URL, network string, isIPv6 bool) {
|
||||
defer wg.Done()
|
||||
|
||||
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
|
||||
@@ -367,7 +367,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String())
|
||||
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), isIPv6)
|
||||
if err != nil {
|
||||
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err)
|
||||
return
|
||||
@@ -397,7 +397,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
|
||||
}
|
||||
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err)
|
||||
}
|
||||
}(*urls[i], networkType.String())
|
||||
}(*urls[i], networkType.String(), networkType.IsIPv6())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -557,7 +557,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du
|
||||
return nil, errNotImplemented
|
||||
}
|
||||
|
||||
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
|
||||
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.getConnForURLTimes++
|
||||
|
87
udp_mux.go
87
udp_mux.go
@@ -14,7 +14,7 @@ import (
|
||||
// UDPMux allows multiple connections to go over a single UDP port
|
||||
type UDPMux interface {
|
||||
io.Closer
|
||||
GetConn(ufrag string) (net.PacketConn, error)
|
||||
GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error)
|
||||
RemoveConnByUfrag(ufrag string)
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ type UDPMuxDefault struct {
|
||||
closedChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
|
||||
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
|
||||
conns map[string]*udpMuxedConn
|
||||
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
|
||||
connsIPv4, connsIPv6 map[string]*udpMuxedConn
|
||||
|
||||
addressMapMu sync.RWMutex
|
||||
addressMap map[string]*udpMuxedConn
|
||||
@@ -54,7 +54,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
m := &UDPMuxDefault{
|
||||
addressMap: map[string]*udpMuxedConn{},
|
||||
params: params,
|
||||
conns: make(map[string]*udpMuxedConn),
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
connsIPv6: make(map[string]*udpMuxedConn),
|
||||
closedChan: make(chan struct{}, 1),
|
||||
pool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
@@ -76,7 +77,7 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||
|
||||
// GetConn returns a PacketConn given the connection's ufrag and network
|
||||
// creates the connection if an existing one can't be found
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -84,8 +85,8 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
if c, ok := m.conns[ufrag]; ok {
|
||||
return c, nil
|
||||
if conn, ok := m.getConn(ufrag, isIPv6); ok {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
c := m.createMuxedConn(ufrag)
|
||||
@@ -93,26 +94,30 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
|
||||
<-c.CloseChannel()
|
||||
m.removeConn(ufrag)
|
||||
}()
|
||||
m.conns[ufrag] = c
|
||||
|
||||
if isIPv6 {
|
||||
m.connsIPv6[ufrag] = c
|
||||
} else {
|
||||
m.connsIPv4[ufrag] = c
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
m.mu.Lock()
|
||||
removedConns := make([]*udpMuxedConn, 0)
|
||||
for key := range m.conns {
|
||||
if key != ufrag {
|
||||
continue
|
||||
}
|
||||
removedConns := make([]*udpMuxedConn, 0, 2)
|
||||
|
||||
c := m.conns[key]
|
||||
delete(m.conns, key)
|
||||
if c != nil {
|
||||
removedConns = append(removedConns, c)
|
||||
}
|
||||
// Keep lock section small to avoid deadlock with conn lock
|
||||
m.mu.Lock()
|
||||
if c, ok := m.connsIPv4[ufrag]; ok {
|
||||
delete(m.connsIPv4, ufrag)
|
||||
removedConns = append(removedConns, c)
|
||||
}
|
||||
if c, ok := m.connsIPv6[ufrag]; ok {
|
||||
delete(m.connsIPv6, ufrag)
|
||||
removedConns = append(removedConns, c)
|
||||
}
|
||||
// keep lock section small to avoid deadlock with conn lock
|
||||
m.mu.Unlock()
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
@@ -143,21 +148,39 @@ func (m *UDPMuxDefault) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, c := range m.conns {
|
||||
for _, c := range m.connsIPv4 {
|
||||
_ = c.Close()
|
||||
}
|
||||
m.conns = make(map[string]*udpMuxedConn)
|
||||
for _, c := range m.connsIPv6 {
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
m.connsIPv4 = make(map[string]*udpMuxedConn)
|
||||
m.connsIPv6 = make(map[string]*udpMuxedConn)
|
||||
|
||||
close(m.closedChan)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) removeConn(key string) {
|
||||
m.mu.Lock()
|
||||
c := m.conns[key]
|
||||
delete(m.conns, key)
|
||||
// keep lock section small to avoid deadlock with conn lock
|
||||
m.mu.Unlock()
|
||||
c := func() *udpMuxedConn {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if c, ok := m.connsIPv4[key]; ok {
|
||||
delete(m.connsIPv4, key)
|
||||
return c
|
||||
}
|
||||
|
||||
if c, ok := m.connsIPv6[key]; ok {
|
||||
delete(m.connsIPv6, key)
|
||||
return c
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
if c == nil {
|
||||
return
|
||||
@@ -255,9 +278,10 @@ func (m *UDPMuxDefault) connWorker() {
|
||||
}
|
||||
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
isIPv6 := udpAddr.IP.To4() == nil
|
||||
|
||||
m.mu.Lock()
|
||||
destinationConn = m.conns[ufrag]
|
||||
destinationConn, _ = m.getConn(ufrag, isIPv6)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -272,6 +296,15 @@ func (m *UDPMuxDefault) connWorker() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
||||
if isIPv6 {
|
||||
val, ok = m.connsIPv6[ufrag]
|
||||
} else {
|
||||
val, ok = m.connsIPv4[ufrag]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type bufferHolder struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
@@ -65,7 +65,7 @@ func TestUDPMux(t *testing.T) {
|
||||
require.NoError(t, udpMux.Close())
|
||||
|
||||
// can't create more connections
|
||||
_, err = udpMux.GetConn("failufrag")
|
||||
_, err = udpMux.GetConn("failufrag", false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ func TestAddressEncoding(t *testing.T) {
|
||||
}
|
||||
|
||||
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
|
||||
pktConn, err := udpMux.GetConn(ufrag)
|
||||
pktConn, err := udpMux.GetConn(ufrag, false)
|
||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||
defer func() {
|
||||
_ = pktConn.Close()
|
||||
|
@@ -16,7 +16,7 @@ type UniversalUDPMux interface {
|
||||
UDPMux
|
||||
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
|
||||
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
|
||||
GetConnForURL(ufrag string, url string) (net.PacketConn, error)
|
||||
GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error)
|
||||
}
|
||||
|
||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
|
||||
@@ -84,8 +84,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
|
||||
|
||||
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
||||
// and return a unique connection per server.
|
||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
|
||||
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url))
|
||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
|
||||
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6)
|
||||
}
|
||||
|
||||
// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.
|
||||
|
@@ -40,7 +40,7 @@ func TestUniversalUDPMux(t *testing.T) {
|
||||
}
|
||||
|
||||
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
|
||||
pktConn, err := udpMux.GetConn(ufrag)
|
||||
pktConn, err := udpMux.GetConn(ufrag, false)
|
||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||
defer func() {
|
||||
_ = pktConn.Close()
|
||||
|
Reference in New Issue
Block a user