mirror of
https://github.com/pion/ice.git
synced 2025-09-28 04:12:09 +08:00
Makes UDPMux IPv4/IPv6 aware
UDPMux before only worked with UDP4 traffic. UDP6 traffic would simply be ignored. This commit implements 2 connections per ufrag. When requesting a connection for a ufrag the user must specify if they want IPv4 or IPv6. Relates to pion/webrtc#1915
This commit is contained in:

committed by
Antoine Baché

parent
427ac0fddb
commit
45ff379fd3
10
gather.go
10
gather.go
@@ -236,7 +236,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
|
|||||||
return errUDPMuxDisabled
|
return errUDPMuxDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, []NetworkType{NetworkTypeUDP4})
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.networkTypes)
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return err
|
||||||
@@ -254,7 +254,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := a.udpMux.GetConn(a.localUfrag)
|
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -351,7 +351,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
|
|||||||
|
|
||||||
for i := range urls {
|
for i := range urls {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(url URL, network string) {
|
go func(url URL, network string, isIPv6 bool) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
|
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
|
||||||
@@ -367,7 +367,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String())
|
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), isIPv6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err)
|
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err)
|
||||||
return
|
return
|
||||||
@@ -397,7 +397,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
|
|||||||
}
|
}
|
||||||
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err)
|
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err)
|
||||||
}
|
}
|
||||||
}(*urls[i], networkType.String())
|
}(*urls[i], networkType.String(), networkType.IsIPv6())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -557,7 +557,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du
|
|||||||
return nil, errNotImplemented
|
return nil, errNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
|
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
m.getConnForURLTimes++
|
m.getConnForURLTimes++
|
||||||
|
87
udp_mux.go
87
udp_mux.go
@@ -14,7 +14,7 @@ import (
|
|||||||
// UDPMux allows multiple connections to go over a single UDP port
|
// UDPMux allows multiple connections to go over a single UDP port
|
||||||
type UDPMux interface {
|
type UDPMux interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
GetConn(ufrag string) (net.PacketConn, error)
|
GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error)
|
||||||
RemoveConnByUfrag(ufrag string)
|
RemoveConnByUfrag(ufrag string)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,8 +25,8 @@ type UDPMuxDefault struct {
|
|||||||
closedChan chan struct{}
|
closedChan chan struct{}
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
|
|
||||||
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
|
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
|
||||||
conns map[string]*udpMuxedConn
|
connsIPv4, connsIPv6 map[string]*udpMuxedConn
|
||||||
|
|
||||||
addressMapMu sync.RWMutex
|
addressMapMu sync.RWMutex
|
||||||
addressMap map[string]*udpMuxedConn
|
addressMap map[string]*udpMuxedConn
|
||||||
@@ -54,7 +54,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|||||||
m := &UDPMuxDefault{
|
m := &UDPMuxDefault{
|
||||||
addressMap: map[string]*udpMuxedConn{},
|
addressMap: map[string]*udpMuxedConn{},
|
||||||
params: params,
|
params: params,
|
||||||
conns: make(map[string]*udpMuxedConn),
|
connsIPv4: make(map[string]*udpMuxedConn),
|
||||||
|
connsIPv6: make(map[string]*udpMuxedConn),
|
||||||
closedChan: make(chan struct{}, 1),
|
closedChan: make(chan struct{}, 1),
|
||||||
pool: &sync.Pool{
|
pool: &sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
@@ -76,7 +77,7 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
|||||||
|
|
||||||
// GetConn returns a PacketConn given the connection's ufrag and network
|
// GetConn returns a PacketConn given the connection's ufrag and network
|
||||||
// creates the connection if an existing one can't be found
|
// creates the connection if an existing one can't be found
|
||||||
func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
|
func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
@@ -84,8 +85,8 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
|
|||||||
return nil, io.ErrClosedPipe
|
return nil, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
|
|
||||||
if c, ok := m.conns[ufrag]; ok {
|
if conn, ok := m.getConn(ufrag, isIPv6); ok {
|
||||||
return c, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := m.createMuxedConn(ufrag)
|
c := m.createMuxedConn(ufrag)
|
||||||
@@ -93,26 +94,30 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
|
|||||||
<-c.CloseChannel()
|
<-c.CloseChannel()
|
||||||
m.removeConn(ufrag)
|
m.removeConn(ufrag)
|
||||||
}()
|
}()
|
||||||
m.conns[ufrag] = c
|
|
||||||
|
if isIPv6 {
|
||||||
|
m.connsIPv6[ufrag] = c
|
||||||
|
} else {
|
||||||
|
m.connsIPv4[ufrag] = c
|
||||||
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||||
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||||
m.mu.Lock()
|
removedConns := make([]*udpMuxedConn, 0, 2)
|
||||||
removedConns := make([]*udpMuxedConn, 0)
|
|
||||||
for key := range m.conns {
|
|
||||||
if key != ufrag {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
c := m.conns[key]
|
// Keep lock section small to avoid deadlock with conn lock
|
||||||
delete(m.conns, key)
|
m.mu.Lock()
|
||||||
if c != nil {
|
if c, ok := m.connsIPv4[ufrag]; ok {
|
||||||
removedConns = append(removedConns, c)
|
delete(m.connsIPv4, ufrag)
|
||||||
}
|
removedConns = append(removedConns, c)
|
||||||
|
}
|
||||||
|
if c, ok := m.connsIPv6[ufrag]; ok {
|
||||||
|
delete(m.connsIPv6, ufrag)
|
||||||
|
removedConns = append(removedConns, c)
|
||||||
}
|
}
|
||||||
// keep lock section small to avoid deadlock with conn lock
|
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
m.addressMapMu.Lock()
|
m.addressMapMu.Lock()
|
||||||
@@ -143,21 +148,39 @@ func (m *UDPMuxDefault) Close() error {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
for _, c := range m.conns {
|
for _, c := range m.connsIPv4 {
|
||||||
_ = c.Close()
|
_ = c.Close()
|
||||||
}
|
}
|
||||||
m.conns = make(map[string]*udpMuxedConn)
|
for _, c := range m.connsIPv6 {
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.connsIPv4 = make(map[string]*udpMuxedConn)
|
||||||
|
m.connsIPv6 = make(map[string]*udpMuxedConn)
|
||||||
|
|
||||||
close(m.closedChan)
|
close(m.closedChan)
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UDPMuxDefault) removeConn(key string) {
|
func (m *UDPMuxDefault) removeConn(key string) {
|
||||||
m.mu.Lock()
|
|
||||||
c := m.conns[key]
|
|
||||||
delete(m.conns, key)
|
|
||||||
// keep lock section small to avoid deadlock with conn lock
|
// keep lock section small to avoid deadlock with conn lock
|
||||||
m.mu.Unlock()
|
c := func() *udpMuxedConn {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if c, ok := m.connsIPv4[key]; ok {
|
||||||
|
delete(m.connsIPv4, key)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
if c, ok := m.connsIPv6[key]; ok {
|
||||||
|
delete(m.connsIPv6, key)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return
|
return
|
||||||
@@ -255,9 +278,10 @@ func (m *UDPMuxDefault) connWorker() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ufrag := strings.Split(string(attr), ":")[0]
|
ufrag := strings.Split(string(attr), ":")[0]
|
||||||
|
isIPv6 := udpAddr.IP.To4() == nil
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
destinationConn = m.conns[ufrag]
|
destinationConn, _ = m.getConn(ufrag, isIPv6)
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -272,6 +296,15 @@ func (m *UDPMuxDefault) connWorker() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
||||||
|
if isIPv6 {
|
||||||
|
val, ok = m.connsIPv6[ufrag]
|
||||||
|
} else {
|
||||||
|
val, ok = m.connsIPv4[ufrag]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
type bufferHolder struct {
|
type bufferHolder struct {
|
||||||
buffer []byte
|
buffer []byte
|
||||||
}
|
}
|
||||||
|
@@ -65,7 +65,7 @@ func TestUDPMux(t *testing.T) {
|
|||||||
require.NoError(t, udpMux.Close())
|
require.NoError(t, udpMux.Close())
|
||||||
|
|
||||||
// can't create more connections
|
// can't create more connections
|
||||||
_, err = udpMux.GetConn("failufrag")
|
_, err = udpMux.GetConn("failufrag", false)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ func TestAddressEncoding(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
|
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
|
||||||
pktConn, err := udpMux.GetConn(ufrag)
|
pktConn, err := udpMux.GetConn(ufrag, false)
|
||||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = pktConn.Close()
|
_ = pktConn.Close()
|
||||||
|
@@ -16,7 +16,7 @@ type UniversalUDPMux interface {
|
|||||||
UDPMux
|
UDPMux
|
||||||
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
|
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
|
||||||
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
|
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
|
||||||
GetConnForURL(ufrag string, url string) (net.PacketConn, error)
|
GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
|
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
|
||||||
@@ -84,8 +84,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
|
|||||||
|
|
||||||
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
||||||
// and return a unique connection per server.
|
// and return a unique connection per server.
|
||||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
|
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
|
||||||
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url))
|
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.
|
// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.
|
||||||
|
@@ -40,7 +40,7 @@ func TestUniversalUDPMux(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
|
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
|
||||||
pktConn, err := udpMux.GetConn(ufrag)
|
pktConn, err := udpMux.GetConn(ufrag, false)
|
||||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = pktConn.Close()
|
_ = pktConn.Close()
|
||||||
|
Reference in New Issue
Block a user