Use MultiUDPMux to implement listen any address

In #475, import low-level API (ReadMsgUDP) to determine
destination interface for packets received by UDPConn listen
at any(unspecified) address, to fix msg received by incorrect
candidate that shared same ufrags. But the api has compatibility
issues, also not reliable in some special network cases like
AWS/ECS.
So this pr revert that change, and make UDPMuxDefault not
accept Conn listen at unspecified address. Also provide a
NewMultiUDPMuxFromPort helper function to create a MultiUDPMux
to listen at all addresses.
For ice gather, it will use UDPMux's listen address to generate
canidates instead of create it from interfaces.
This commit is contained in:
cnderrauber
2022-10-09 00:39:39 +08:00
committed by cnderrauber
parent af9281dc76
commit 04a6027e93
14 changed files with 499 additions and 518 deletions

View File

@@ -24,11 +24,14 @@ func TestMuxAgent(t *testing.T) {
const muxPort = 7686 const muxPort = 7686
c, err := net.ListenUDP("udp4", &net.UDPAddr{ c, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: muxPort, Port: muxPort,
}) })
require.NoError(t, err)
loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory := logging.NewDefaultLoggerFactory()
udpMux := NewUDPMuxDefault(UDPMuxParams{ udpMux, err := NewUDPMuxDefault(UDPMuxParams{
Logger: loggerFactory.NewLogger("ice"), Logger: loggerFactory.NewLogger("ice"),
UDPConn: c, UDPConn: c,
}) })

View File

@@ -134,11 +134,12 @@ var (
errMismatchUsername = errors.New("username mismatch") errMismatchUsername = errors.New("username mismatch")
errICEWriteSTUNMessage = errors.New("the ICE conn can't write STUN messages") errICEWriteSTUNMessage = errors.New("the ICE conn can't write STUN messages")
errUDPMuxDisabled = errors.New("UDPMux is not enabled") errUDPMuxDisabled = errors.New("UDPMux is not enabled")
errCandidateIPNotFound = errors.New("could not determine local IP for Mux candidate")
errNoXorAddrMapping = errors.New("no address mapping") errNoXorAddrMapping = errors.New("no address mapping")
errSendSTUNPacket = errors.New("failed to send STUN packet") errSendSTUNPacket = errors.New("failed to send STUN packet")
errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr") errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr")
errNotImplemented = errors.New("not implemented yet") errNotImplemented = errors.New("not implemented yet")
errNoUDPMuxAvailable = errors.New("no UDP mux is available") errNoUDPMuxAvailable = errors.New("no UDP mux is available")
errNoTCPMuxAvailable = errors.New("no TCP mux is available") errNoTCPMuxAvailable = errors.New("no TCP mux is available")
errListenUnspecified = errors.New("can't listen on unspecified address")
errInvalidAddress = errors.New("invalid address")
) )

View File

@@ -273,16 +273,14 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
return errUDPMuxDisabled return errUDPMuxDisabled
} }
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes) localAddresses := a.udpMux.GetListenAddresses()
switch {
case err != nil:
return err
case len(localIPs) == 0:
return errCandidateIPNotFound
}
for _, candidateIP := range localIPs { for _, addr := range localAddresses {
localIP := candidateIP udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return errInvalidAddress
}
candidateIP := udpAddr.IP
if a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost { if a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
if mappedIP, innerErr := a.extIPMapper.findExternalIP(candidateIP.String()); innerErr != nil { if mappedIP, innerErr := a.extIPMapper.findExternalIP(candidateIP.String()); innerErr != nil {
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String()) a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String())
@@ -292,31 +290,10 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
} }
} }
var conns []net.PacketConn conn, err := a.udpMux.GetConn(a.localUfrag, udpAddr)
if multi, ok := a.udpMux.(AllConnsGetter); ok {
conns, err = multi.GetAllConns(a.localUfrag, candidateIP.To4() == nil, localIP)
if err != nil { if err != nil {
return err return err
} }
if len(conns) == 0 {
a.log.Warnf("Failed to get any connections from MultiUDPMux for candidate: %s", candidateIP)
continue
}
} else {
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil, localIP)
if err != nil {
return err
}
conns = []net.PacketConn{conn}
}
for _, conn := range conns {
udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create host mux candidate: %s failed to cast", candidateIP))
continue
}
hostConfig := CandidateHostConfig{ hostConfig := CandidateHostConfig{
Network: udp, Network: udp,
Address: candidateIP.String(), Address: candidateIP.String(),
@@ -339,7 +316,6 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
continue continue
} }
} }
}
return nil return nil
} }
@@ -414,8 +390,14 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
} }
for i := range urls { for i := range urls {
for _, listenAddr := range a.udpMuxSrflx.GetListenAddresses() {
udpAddr, ok := listenAddr.(*net.UDPAddr)
if !ok {
a.log.Warn("Failed to cast udpMuxSrflx listen address to UDPAddr")
continue
}
wg.Add(1) wg.Add(1)
go func(url URL, network string, isIPv6 bool) { go func(url URL, network string, localAddr *net.UDPAddr) {
defer wg.Done() defer wg.Done()
hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
@@ -431,7 +413,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
return return
} }
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), isIPv6) conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), localAddr)
if err != nil { if err != nil {
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v", network, url, err) a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v", network, url, err)
return return
@@ -440,19 +422,13 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
ip := xoraddr.IP ip := xoraddr.IP
port := xoraddr.Port port := xoraddr.Port
laddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create server reflexive candidate: %s %s %d: cast failed", network, ip, port))
return
}
srflxConfig := CandidateServerReflexiveConfig{ srflxConfig := CandidateServerReflexiveConfig{
Network: network, Network: network,
Address: ip.String(), Address: ip.String(),
Port: port, Port: port,
Component: ComponentRTP, Component: ComponentRTP,
RelAddr: laddr.IP.String(), RelAddr: localAddr.IP.String(),
RelPort: laddr.Port, RelPort: localAddr.Port,
} }
c, err := NewCandidateServerReflexive(&srflxConfig) c, err := NewCandidateServerReflexive(&srflxConfig)
if err != nil { if err != nil {
@@ -466,7 +442,8 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
} }
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
} }
}(*urls[i], networkType.String(), networkType.IsIPv6()) }(*urls[i], networkType.String(), udpAddr)
}
} }
} }
} }

View File

@@ -507,9 +507,9 @@ func TestMultiUDPMuxUsage(t *testing.T) {
}() }()
expectedPorts = append(expectedPorts, port) expectedPorts = append(expectedPorts, port)
udpMuxInstances = append(udpMuxInstances, NewUDPMuxDefault(UDPMuxParams{ muxDefault, err := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn})
UDPConn: conn, assert.NoError(t, err)
})) udpMuxInstances = append(udpMuxInstances, muxDefault)
idx := i idx := i
defer func() { defer func() {
_ = udpMuxInstances[idx].Close() _ = udpMuxInstances[idx].Close()
@@ -675,7 +675,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du
return nil, errNotImplemented return nil, errNotImplemented
} }
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) { func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.getConnForURLTimes++ m.getConnForURLTimes++
@@ -694,3 +694,7 @@ func (m *universalUDPMuxMock) RemoveConnByUfrag(ufrag string) {
defer m.mu.Unlock() defer m.mu.Unlock()
m.removeConnByUfragTimes++ m.removeConnByUfragTimes++
} }
func (m *universalUDPMuxMock) GetListenAddresses() []net.Addr {
return []net.Addr{m.conn.LocalAddr()}
}

View File

@@ -5,6 +5,14 @@ package ice
import "net" import "net"
// AllConnsGetter allows multiple fixed TCP ports to be used,
// each which is multiplexed like TCPMux. AllConnsGetter also acts as
// a TCPMux, in which case it will return a single connection for one
// of the ports.
type AllConnsGetter interface {
GetAllConns(ufrag string, isIPv6 bool, localIP net.IP) ([]net.PacketConn, error)
}
// MultiTCPMuxDefault implements both TCPMux and AllConnsGetter, // MultiTCPMuxDefault implements both TCPMux and AllConnsGetter,
// allowing users to pass multiple TCPMux instances to the ICE agent // allowing users to pass multiple TCPMux instances to the ICE agent
// configuration. // configuration.

View File

@@ -15,8 +15,9 @@ import (
// UDPMux allows multiple connections to go over a single UDP port // UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface { type UDPMux interface {
io.Closer io.Closer
GetConn(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string) RemoveConnByUfrag(ufrag string)
GetListenAddresses() []net.Addr
} }
// UDPMuxDefault is an implementation of the interface // UDPMuxDefault is an implementation of the interface
@@ -26,13 +27,11 @@ type UDPMuxDefault struct {
closedChan chan struct{} closedChan chan struct{}
closeOnce sync.Once closeOnce sync.Once
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType // conns are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]map[ipAddr]*udpMuxedConn conns map[string]*udpMuxedConn
addressMapMu sync.RWMutex addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
// remote address (ip:port) -> (localip -> udpMuxedConn)
addressMap map[string]map[ipAddr]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes // buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool pool *sync.Pool
@@ -42,37 +41,28 @@ type UDPMuxDefault struct {
const maxAddrSize = 512 const maxAddrSize = 512
// UDPMuxConn is a udp PacketConn with ReadMsgUDP and File method
// to retrieve the destination local address of the received packet
type UDPMuxConn interface {
net.PacketConn
// ReadMsgUdp used to get destination address when received a udp packet
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
// File returns a copy of the underlying os.File.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
File() (f *os.File, err error)
}
// UDPMuxParams are parameters for UDPMux. // UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct { type UDPMuxParams struct {
Logger logging.LeveledLogger Logger logging.LeveledLogger
UDPConn UDPMuxConn UDPConn net.PacketConn
} }
// NewUDPMuxDefault creates an implementation of UDPMux // NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewUDPMuxDefault(params UDPMuxParams) (*UDPMuxDefault, error) {
if params.Logger == nil { if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
} }
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
return nil, errInvalidAddress
} else if ok && addr.IP.IsUnspecified() {
return nil, errListenUnspecified
}
m := &UDPMuxDefault{ m := &UDPMuxDefault{
addressMap: make(map[string]map[ipAddr]*udpMuxedConn), addressMap: map[string]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]map[ipAddr]*udpMuxedConn), conns: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]map[ipAddr]*udpMuxedConn),
closedChan: make(chan struct{}, 1), closedChan: make(chan struct{}, 1),
pool: &sync.Pool{ pool: &sync.Pool{
New: func() interface{} { New: func() interface{} {
@@ -84,7 +74,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
go m.connWorker() go m.connWorker()
return m return m, nil
} }
// LocalAddr returns the listening address of this UDPMuxDefault // LocalAddr returns the listening address of this UDPMuxDefault
@@ -92,9 +82,17 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr() return m.params.UDPConn.LocalAddr()
} }
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}
}
// GetConn returns a PacketConn given the connection's ufrag and network // GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) { func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
if m.params.UDPConn.LocalAddr() != addr {
return nil, errInvalidAddress
}
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -102,60 +100,34 @@ func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool, local net.IP) (net.Pa
return nil, io.ErrClosedPipe return nil, io.ErrClosedPipe
} }
if conn, ok := m.getConn(ufrag, isIPv6, local); ok { if conn, ok := m.getConn(ufrag); ok {
return conn, nil return conn, nil
} }
c, err := m.createMuxedConn(ufrag, local) c := m.createMuxedConn(ufrag)
if err != nil {
return nil, err
}
go func() { go func() {
<-c.CloseChannel() <-c.CloseChannel()
m.removeConnByUfragAndLocalHost(ufrag, local) m.RemoveConnByUfrag(ufrag)
}() }()
var ( m.conns[ufrag] = c
conns map[ipAddr]*udpMuxedConn
ok bool
)
if isIPv6 {
if conns, ok = m.connsIPv6[ufrag]; !ok {
conns = make(map[ipAddr]*udpMuxedConn)
m.connsIPv6[ufrag] = conns
}
} else {
if conns, ok = m.connsIPv4[ufrag]; !ok {
conns = make(map[ipAddr]*udpMuxedConn)
m.connsIPv4[ufrag] = conns
}
}
conns[ipAddr(local.String())] = c
return c, nil return c, nil
} }
// RemoveConnByUfrag stops and removes the muxed packet connection // RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 4) var removedConn *udpMuxedConn
// Keep lock section small to avoid deadlock with conn lock // Keep lock section small to avoid deadlock with conn lock
m.mu.Lock() m.mu.Lock()
if conns, ok := m.connsIPv4[ufrag]; ok { if c, ok := m.conns[ufrag]; ok {
delete(m.connsIPv4, ufrag) delete(m.conns, ufrag)
for _, c := range conns { removedConn = c
removedConns = append(removedConns, c)
}
}
if conns, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
for _, c := range conns {
removedConns = append(removedConns, c)
}
} }
m.mu.Unlock() m.mu.Unlock()
if len(removedConns) == 0 { if removedConn == nil {
// No need to lock if no connection was found // No need to lock if no connection was found
return return
} }
@@ -163,64 +135,10 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.addressMapMu.Lock() m.addressMapMu.Lock()
defer m.addressMapMu.Unlock() defer m.addressMapMu.Unlock()
for _, c := range removedConns { addresses := removedConn.getAddresses()
addresses := c.getAddresses()
for _, addr := range addresses { for _, addr := range addresses {
if conns, ok := m.addressMap[addr]; ok {
delete(conns, ipAddr(c.params.LocalIP.String()))
if len(conns) == 0 {
delete(m.addressMap, addr) delete(m.addressMap, addr)
} }
}
}
}
}
func (m *UDPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP) {
removedConns := make([]*udpMuxedConn, 0, 4)
localIP := ipAddr(local.String())
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if conns, ok := m.connsIPv4[ufrag]; ok {
if conn, ok := conns[localIP]; ok {
delete(conns, localIP)
if len(conns) == 0 {
delete(m.connsIPv4, ufrag)
}
removedConns = append(removedConns, conn)
}
}
if conns, ok := m.connsIPv6[ufrag]; ok {
if conn, ok := conns[localIP]; ok {
delete(conns, localIP)
if len(conns) == 0 {
delete(m.connsIPv6, ufrag)
}
removedConns = append(removedConns, conn)
}
}
m.mu.Unlock()
if len(removedConns) == 0 {
// No need to lock if no connection was found
return
}
m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
if conns, ok := m.addressMap[addr]; ok {
delete(conns, ipAddr(c.params.LocalIP.String()))
if len(conns) == 0 {
delete(m.addressMap, addr)
}
}
}
}
} }
// IsClosed returns true if the mux had been closed // IsClosed returns true if the mux had been closed
@@ -240,40 +158,15 @@ func (m *UDPMuxDefault) Close() error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
for _, conns := range m.connsIPv4 { for _, c := range m.conns {
for _, c := range conns {
_ = c.Close() _ = c.Close()
} }
}
for _, conns := range m.connsIPv6 {
for _, c := range conns {
_ = c.Close()
}
}
m.connsIPv4 = make(map[string]map[ipAddr]*udpMuxedConn) m.conns = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]map[ipAddr]*udpMuxedConn)
// ReadMsgUDP will block until something is received, otherwise it will block forever
// and the Conn's Close method too. So send a packet to wake it for exit.
close(m.closedChan) close(m.closedChan)
closeConn, errConn := net.DialUDP("udp", nil, m.params.UDPConn.LocalAddr().(*net.UDPAddr))
// i386 doesn't support dial local ipv6 address _ = m.params.UDPConn.Close()
if errConn != nil && strings.Contains(errConn.Error(), "dial udp [::]:") &&
strings.Contains(errConn.Error(), "connect: cannot assign requested address") {
closeConn, errConn = net.DialUDP("udp4", nil, &net.UDPAddr{Port: m.params.UDPConn.LocalAddr().(*net.UDPAddr).Port})
}
if errConn != nil {
m.params.Logger.Errorf("Failed to open close notify socket, %v", errConn)
} else {
defer func() {
_ = closeConn.Close()
}()
_, errConn = closeConn.Write([]byte("close"))
if errConn != nil {
m.params.Logger.Errorf("Failed to send close notify msg, %v", errConn)
}
}
}) })
return err return err
} }
@@ -290,58 +183,36 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
m.addressMapMu.Lock() m.addressMapMu.Lock()
defer m.addressMapMu.Unlock() defer m.addressMapMu.Unlock()
conns, ok := m.addressMap[addr] existing, ok := m.addressMap[addr]
if ok {
existing, ok := conns[ipAddr(conn.params.LocalIP.String())]
if ok { if ok {
existing.removeAddress(addr) existing.removeAddress(addr)
} }
} else { m.addressMap[addr] = conn
conns = make(map[ipAddr]*udpMuxedConn)
m.addressMap[addr] = conns
}
conns[ipAddr(conn.params.LocalIP.String())] = conn
m.params.Logger.Debugf("Registered %s for %s, local %s", addr, conn.params.Key, conn.params.LocalIP.String()) m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key)
} }
func (m *UDPMuxDefault) createMuxedConn(key string, local net.IP) (*udpMuxedConn, error) { func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
m.params.Logger.Debugf("Creating new muxed connection, key:%s local:%s ", key, local.String())
addr, ok := m.LocalAddr().(*net.UDPAddr)
if !ok {
return nil, ErrGetTransportAddress
}
localAddr := *addr
localAddr.IP = local
c := newUDPMuxedConn(&udpMuxedConnParams{ c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m, Mux: m,
Key: key, Key: key,
AddrPool: m.pool, AddrPool: m.pool,
LocalAddr: &localAddr, LocalAddr: m.LocalAddr(),
LocalIP: local,
Logger: m.params.Logger, Logger: m.params.Logger,
}) })
return c, nil return c
} }
func (m *UDPMuxDefault) connWorker() { //nolint:gocognit func (m *UDPMuxDefault) connWorker() {
logger := m.params.Logger logger := m.params.Logger
defer func() { defer func() {
_ = m.Close() _ = m.Close()
}() }()
localUDPAddr, _ := m.LocalAddr().(*net.UDPAddr)
buf := make([]byte, receiveMTU) buf := make([]byte, receiveMTU)
file, _ := m.params.UDPConn.File()
setUDPSocketOptionsForLocalAddr(file.Fd(), m.params.Logger)
_ = file.Close()
oob := make([]byte, receiveMTU)
for { for {
localHost := localUDPAddr.IP n, addr, err := m.params.UDPConn.ReadFrom(buf)
n, oobn, _, addr, err := m.params.UDPConn.ReadMsgUDP(buf, oob)
if m.IsClosed() { if m.IsClosed() {
return return
} else if err != nil { } else if err != nil {
@@ -354,29 +225,19 @@ func (m *UDPMuxDefault) connWorker() { //nolint:gocognit
return return
} }
// get destination local addr from received packet udpAddr, ok := addr.(*net.UDPAddr)
if oobIP, addrErr := getLocalAddrFromOob(oob[:oobn]); addrErr == nil { if !ok {
localHost = oobIP logger.Errorf("underlying PacketConn did not return a UDPAddr")
} else { return
m.params.Logger.Warnf("could not get local addr from oob: %v, remote %s", addrErr, addr)
} }
// If we have already seen this address dispatch to the appropriate destination // If we have already seen this address dispatch to the appropriate destination
var destinationConn *udpMuxedConn
m.addressMapMu.Lock() m.addressMapMu.Lock()
if conns, ok := m.addressMap[addr.String()]; ok { destinationConn := m.addressMap[addr.String()]
destinationConn, ok = conns[ipAddr(localHost.String())]
if !ok {
for _, c := range conns {
destinationConn = c
break
}
}
}
m.addressMapMu.Unlock() m.addressMapMu.Unlock()
// If we haven't seen this address before but is a STUN packet lookup by ufrag // If we haven't seen this address before but is a STUN packet lookup by ufrag
if destinationConn == nil && stun.IsMessage(buf[:n]) && !localHost.IsUnspecified() { if destinationConn == nil && stun.IsMessage(buf[:n]) {
msg := &stun.Message{ msg := &stun.Message{
Raw: append([]byte{}, buf[:n]...), Raw: append([]byte{}, buf[:n]...),
} }
@@ -393,34 +254,25 @@ func (m *UDPMuxDefault) connWorker() { //nolint:gocognit
} }
ufrag := strings.Split(string(attr), ":")[0] ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := addr.IP.To4() == nil
m.mu.Lock() m.mu.Lock()
destinationConn, _ = m.getConn(ufrag, isIPv6, localHost) destinationConn, _ = m.getConn(ufrag)
m.mu.Unlock() m.mu.Unlock()
} }
if destinationConn == nil { if destinationConn == nil {
m.params.Logger.Tracef("dropping packet from %s", addr.String()) m.params.Logger.Tracef("dropping packet from %s, addr: %s", udpAddr.String(), addr.String())
continue continue
} }
if err = destinationConn.writePacket(buf[:n], addr); err != nil { if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil {
m.params.Logger.Errorf("could not write packet: %v", err) m.params.Logger.Errorf("could not write packet: %v", err)
} }
} }
} }
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *udpMuxedConn, ok bool) { func (m *UDPMuxDefault) getConn(ufrag string) (val *udpMuxedConn, ok bool) {
var conns map[ipAddr]*udpMuxedConn val, ok = m.conns[ufrag]
if isIPv6 {
conns, ok = m.connsIPv6[ufrag]
} else {
conns, ok = m.connsIPv4[ufrag]
}
if conns != nil {
val, ok = conns[ipAddr(local.String())]
}
return return
} }

View File

@@ -3,42 +3,44 @@
//nolint:dupl //nolint:dupl
package ice package ice
import "net" import (
"net"
// AllConnsGetter allows multiple fixed UDP or TCP ports to be used, "github.com/pion/logging"
// each which is multiplexed like UDPMux. AllConnsGetter also acts as "github.com/pion/transport/vnet"
// a UDPMux, in which case it will return a single connection for one )
// of the ports.
type AllConnsGetter interface {
GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error)
}
// MultiUDPMuxDefault implements both UDPMux and AllConnsGetter, // MultiUDPMuxDefault implements both UDPMux and AllConnsGetter,
// allowing users to pass multiple UDPMux instances to the ICE agent // allowing users to pass multiple UDPMux instances to the ICE agent
// configuration. // configuration.
type MultiUDPMuxDefault struct { type MultiUDPMuxDefault struct {
muxs []UDPMux muxs []UDPMux
localAddrToMux map[string]UDPMux
} }
// NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that // NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that
// uses the provided UDPMux instances. // uses the provided UDPMux instances.
func NewMultiUDPMuxDefault(muxs ...UDPMux) *MultiUDPMuxDefault { func NewMultiUDPMuxDefault(muxs ...UDPMux) *MultiUDPMuxDefault {
addrToMux := make(map[string]UDPMux)
for _, mux := range muxs {
for _, addr := range mux.GetListenAddresses() {
addrToMux[addr.String()] = mux
}
}
return &MultiUDPMuxDefault{ return &MultiUDPMuxDefault{
muxs: muxs, muxs: muxs,
localAddrToMux: addrToMux,
} }
} }
// GetConn returns a PacketConn given the connection's ufrag and network // GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found. This, unlike // creates the connection if an existing one can't be found.
// GetAllConns, will only return a single PacketConn from the first func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// mux that was passed in to NewMultiUDPMuxDefault. mux, ok := m.localAddrToMux[addr.String()]
func (m *MultiUDPMuxDefault) GetConn(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) { if !ok {
// NOTE: We always use the first element here in order to maintain the
// behavior of using an existing connection if one exists.
if len(m.muxs) == 0 {
return nil, errNoUDPMuxAvailable return nil, errNoUDPMuxAvailable
} }
return m.muxs[0].GetConn(ufrag, isIPv6, local) return mux.GetConn(ufrag, addr)
} }
// RemoveConnByUfrag stops and removes the muxed packet connection // RemoveConnByUfrag stops and removes the muxed packet connection
@@ -49,26 +51,6 @@ func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) {
} }
} }
// GetAllConns returns a PacketConn for each underlying UDPMux
func (m *MultiUDPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) {
if len(m.muxs) == 0 {
// Make sure that we either return at least one connection or an error.
return nil, errNoUDPMuxAvailable
}
var conns []net.PacketConn
for _, mux := range m.muxs {
conn, err := mux.GetConn(ufrag, isIPv6, local)
if err != nil {
// For now, this implementation is all or none.
return nil, err
}
if conn != nil {
conns = append(conns, conn)
}
}
return conns, nil
}
// Close the multi mux, no further connections could be created // Close the multi mux, no further connections could be created
func (m *MultiUDPMuxDefault) Close() error { func (m *MultiUDPMuxDefault) Close() error {
var err error var err error
@@ -79,3 +61,146 @@ func (m *MultiUDPMuxDefault) Close() error {
} }
return err return err
} }
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr {
addrs := make([]net.Addr, 0, len(m.localAddrToMux))
for _, mux := range m.muxs {
addrs = append(addrs, mux.GetListenAddresses()...)
}
return addrs
}
// NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that
// listen all interfaces on the provided port.
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) {
params := multiUDPMuxFromPortParam{
networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
}
for _, opt := range opts {
opt.apply(&params)
}
muxNet := vnet.NewNet(nil)
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks)
if err != nil {
return nil, err
}
conns := make([]net.PacketConn, 0, len(ips))
for _, ip := range ips {
conn, listenErr := net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port})
if listenErr != nil {
err = listenErr
break
}
if params.readBufferSize > 0 {
_ = conn.SetReadBuffer(params.readBufferSize)
}
if params.writeBufferSize > 0 {
_ = conn.SetWriteBuffer(params.writeBufferSize)
}
conns = append(conns, conn)
}
if err != nil {
for _, conn := range conns {
_ = conn.Close()
}
return nil, err
}
muxs := make([]UDPMux, 0, len(conns))
for _, conn := range conns {
mux, muxErr := NewUDPMuxDefault(UDPMuxParams{Logger: params.logger, UDPConn: conn})
if muxErr != nil {
err = muxErr
break
}
muxs = append(muxs, mux)
}
if err != nil {
for _, mux := range muxs {
_ = mux.Close()
}
return nil, err
}
return NewMultiUDPMuxDefault(muxs...), nil
}
// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort
type UDPMuxFromPortOption interface {
apply(*multiUDPMuxFromPortParam)
}
type multiUDPMuxFromPortParam struct {
ifFilter func(string) bool
ipFilter func(ip net.IP) bool
networks []NetworkType
readBufferSize int
writeBufferSize int
logger logging.LeveledLogger
}
type udpMuxFromPortOption struct {
f func(*multiUDPMuxFromPortParam)
}
func (o *udpMuxFromPortOption) apply(p *multiUDPMuxFromPortParam) {
o.f(p)
}
// UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used
func UDPMuxFromPortWithInterfaceFilter(f func(string) bool) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.ifFilter = f
},
}
}
// UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used
func UDPMuxFromPortWithIPFilter(f func(ip net.IP) bool) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.ipFilter = f
},
}
}
// UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6
func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.networks = networks
},
}
}
// UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size
func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.readBufferSize = size
},
}
}
// UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size
func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.writeBufferSize = size
},
}
}
// UDPMuxFromPortWithLogger set the logger for the created UDPMux
func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.logger = logger
},
}
}

View File

@@ -5,6 +5,7 @@ package ice
import ( import (
"net" "net"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -20,16 +21,32 @@ func TestMultiUDPMux(t *testing.T) {
lim := test.TimeOut(time.Second * 30) lim := test.TimeOut(time.Second * 30)
defer lim.Stop() defer lim.Stop()
conn1, err := net.ListenUDP(udp, &net.UDPAddr{}) conn1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err) require.NoError(t, err)
conn2, err := net.ListenUDP(udp, &net.UDPAddr{}) conn2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err) require.NoError(t, err)
udpMuxMulti := NewMultiUDPMuxDefault( conn3, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6loopback})
NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1}), if err != nil {
NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2}), // ipv6 is not supported on this machine
) t.Log("ipv6 is not supported on this machine")
}
muxes := []UDPMux{}
muxV41, err := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1})
require.NoError(t, err)
muxes = append(muxes, muxV41)
muxV42, err := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2})
require.NoError(t, err)
muxes = append(muxes, muxV42)
if conn3 != nil {
muxV6, v6err := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn3})
require.NoError(t, v6err)
muxes = append(muxes, muxV6)
}
udpMuxMulti := NewMultiUDPMuxDefault(muxes...)
defer func() { defer func() {
_ = udpMuxMulti.Close() _ = udpMuxMulti.Close()
_ = conn1.Close() _ = conn1.Close()
@@ -60,32 +77,77 @@ func TestMultiUDPMux(t *testing.T) {
require.NoError(t, udpMuxMulti.Close()) require.NoError(t, udpMuxMulti.Close())
// can't create more connections // can't create more connections
_, err = udpMuxMulti.GetConn("failufrag", false, net.IP{}) _, err = udpMuxMulti.GetConn("failufrag", conn1.LocalAddr())
require.Error(t, err) require.Error(t, err)
} }
func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) { func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) {
pktConns, err := udpMuxMulti.GetAllConns(ufrag, false, net.IP{127, 0, 0, 1}) addrs := udpMuxMulti.GetListenAddresses()
pktConns := make([]net.PacketConn, 0, len(addrs))
for _, addr := range addrs {
udpAddr, ok := addr.(*net.UDPAddr)
require.True(t, ok)
if network == "udp4" && udpAddr.IP.To4() == nil {
continue
} else if network == "udp6" && udpAddr.IP.To4() != nil {
continue
}
c, err := udpMuxMulti.GetConn(ufrag, addr)
require.NoError(t, err, "error retrieving muxed connection for ufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag")
pktConns = append(pktConns, c)
}
defer func() { defer func() {
for _, c := range pktConns { for _, c := range pktConns {
_ = c.Close() _ = c.Close()
} }
}() }()
require.Len(t, pktConns, len(udpMuxMulti.muxs), "there should be a PacketConn for every mux")
// Try talking with each PacketConn // Try talking with each PacketConn
for i, pktConn := range pktConns { for _, pktConn := range pktConns {
remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{ remoteConn, err := net.DialUDP(network, nil, pktConn.LocalAddr().(*net.UDPAddr))
Port: pktConn.LocalAddr().(*net.UDPAddr).Port,
})
require.NoError(t, err, "error dialing test udp connection") require.NoError(t, err, "error dialing test udp connection")
localConn, err := udpMuxMulti.muxs[i].GetConn(ufrag, false, remoteConn.RemoteAddr().(*net.UDPAddr).IP) testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
}()
testMuxConnectionPair(t, localConn, remoteConn, ufrag)
} }
} }
func TestUnspecifiedUDPMux(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
muxPort := 7778
udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(func(s string) bool {
return !strings.Contains(s, "docker")
}))
require.NoError(t, err)
require.GreaterOrEqual(t, len(udpMuxMulti.muxs), 1, "at least have 1 muxs")
defer func() {
_ = udpMuxMulti.Close()
}()
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp)
}()
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", "udp4")
}()
// skip ipv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 {
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", "udp6")
}
wg.Wait()
require.NoError(t, udpMuxMulti.Close())
}

View File

@@ -25,10 +25,21 @@ func TestUDPMux(t *testing.T) {
lim := test.TimeOut(time.Second * 30) lim := test.TimeOut(time.Second * 30)
defer lim.Stop() defer lim.Stop()
conn, err := net.ListenUDP(udp, &net.UDPAddr{}) conn4, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err) require.NoError(t, err)
udpMux := NewUDPMuxDefault(UDPMuxParams{ conn6, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv6loopback})
if err != nil {
t.Log("IPv6 is not supported on this machine")
}
for network, c := range map[string]net.PacketConn{"udp4": conn4, "udp6": conn6} {
if c == nil {
continue
}
conn := c
t.Run(network, func(t *testing.T) {
udpMux, err := NewUDPMuxDefault(UDPMuxParams{
Logger: nil, Logger: nil,
UDPConn: conn, UDPConn: conn,
}) })
@@ -49,16 +60,11 @@ func TestUDPMux(t *testing.T) {
defer wg.Done() defer wg.Done()
testMuxConnection(t, udpMux, "ufrag1", udp) testMuxConnection(t, udpMux, "ufrag1", udp)
}() }()
wg.Add(1)
go func() {
defer wg.Done()
testMuxConnection(t, udpMux, "ufrag2", "udp4")
}()
// skip ipv6 test on i386 // skip ipv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63) const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 { if ptrSize != 32 || network != "udp6" {
testMuxConnection(t, udpMux, "ufrag3", "udp6") testMuxConnection(t, udpMux, "ufrag2", network)
} }
wg.Wait() wg.Wait()
@@ -66,8 +72,24 @@ func TestUDPMux(t *testing.T) {
require.NoError(t, udpMux.Close()) require.NoError(t, udpMux.Close())
// can't create more connections // can't create more connections
_, err = udpMux.GetConn("failufrag", false, net.IPv4zero) _, err = udpMux.GetConn("failufrag", udpMux.LocalAddr())
require.Error(t, err) require.Error(t, err)
})
}
}
func TestCantMuxUnspecifiedAddr(t *testing.T) {
conn, err := net.ListenUDP(udp, &net.UDPAddr{})
require.NoError(t, err)
_, err = NewUDPMuxDefault(UDPMuxParams{
Logger: nil,
UDPConn: conn,
})
require.Equal(t, errListenUnspecified, err)
_ = conn.Close()
} }
func TestAddressEncoding(t *testing.T) { func TestAddressEncoding(t *testing.T) {
@@ -111,17 +133,15 @@ func TestAddressEncoding(t *testing.T) {
} }
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) { func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{ pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
Port: udpMux.LocalAddr().(*net.UDPAddr).Port,
})
require.NoError(t, err, "error dialing test udp connection")
pktConn, err := udpMux.GetConn(ufrag, false, remoteConn.RemoteAddr().(*net.UDPAddr).IP)
require.NoError(t, err, "error retrieving muxed connection for ufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() { defer func() {
_ = pktConn.Close() _ = pktConn.Close()
}() }()
remoteConn, err := net.DialUDP(network, nil, pktConn.LocalAddr().(*net.UDPAddr))
require.NoError(t, err, "error dialing test udp connection")
testMuxConnectionPair(t, pktConn, remoteConn, ufrag) testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
} }

View File

@@ -16,7 +16,7 @@ type UniversalUDPMux interface {
UDPMux UDPMux
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error)
} }
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
@@ -33,12 +33,12 @@ type UniversalUDPMuxDefault struct {
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive. // UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
type UniversalUDPMuxParams struct { type UniversalUDPMuxParams struct {
Logger logging.LeveledLogger Logger logging.LeveledLogger
UDPConn UDPMuxConn UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration XORMappedAddrCacheTTL time.Duration
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) (*UniversalUDPMuxDefault, error) {
if params.Logger == nil { if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
} }
@@ -54,7 +54,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages // wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker) // before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{ m.params.UDPConn = &udpConn{
UDPMuxConn: params.UDPConn, PacketConn: params.UDPConn,
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
} }
@@ -64,14 +64,18 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
} }
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) muxDefault, err := NewUDPMuxDefault(udpMuxParams)
if err != nil {
return nil, err
}
m.UDPMuxDefault = muxDefault
return m return m, nil
} }
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets // udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct { type udpConn struct {
UDPMuxConn net.PacketConn
mux *UniversalUDPMuxDefault mux *UniversalUDPMuxDefault
logger logging.LeveledLogger logger logging.LeveledLogger
} }
@@ -84,48 +88,44 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server. // and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) { func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6, net.IPv4zero) return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
} }
// ReadMsgUDP is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address. // ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.
// It passes processed packets further to the UDPMux (maybe this is not really necessary). // It passes processed packets further to the UDPMux (maybe this is not really necessary).
func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, oobn, flags, addr, err = c.UDPMuxConn.ReadMsgUDP(b, oob) n, addr, err = c.PacketConn.ReadFrom(p)
if err != nil { if err != nil {
return return
} }
if stun.IsMessage(b[:n]) { if stun.IsMessage(p[:n]) {
bytes := make([]byte, n)
copy(bytes, b[:n])
msg := &stun.Message{ msg := &stun.Message{
Raw: bytes, Raw: append([]byte{}, p[:n]...),
} }
if err = msg.Decode(); err != nil { if err = msg.Decode(); err != nil {
c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err) c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
err = nil return n, addr, nil
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// message about this err will be logged in the UDPMux
return return
} }
if c.mux.isXORMappedResponse(msg, addr.String()) { if c.mux.isXORMappedResponse(msg, udpAddr.String()) {
err = c.mux.handleXORMappedResponse(addr, msg) err = c.mux.handleXORMappedResponse(udpAddr, msg)
if err != nil { if err != nil {
c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err) c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err)
err = nil return n, addr, nil
return
} }
return return
} }
} }
return return n, addr, err
}
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
oob := make([]byte, 100)
n, _, _, addr, err = c.ReadMsgUDP(p, oob)
return
} }
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -14,10 +14,10 @@ import (
) )
func TestUniversalUDPMux(t *testing.T) { func TestUniversalUDPMux(t *testing.T) {
conn, err := net.ListenUDP(udp, &net.UDPAddr{}) conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err) require.NoError(t, err)
udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{ udpMux, err := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{
Logger: nil, Logger: nil,
UDPConn: conn, UDPConn: conn,
}) })
@@ -41,7 +41,7 @@ func TestUniversalUDPMux(t *testing.T) {
} }
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) { func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag, false, net.IPv4zero) pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
require.NoError(t, err, "error retrieving muxed connection for ufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() { defer func() {
_ = pktConn.Close() _ = pktConn.Close()

View File

@@ -16,7 +16,6 @@ type udpMuxedConnParams struct {
AddrPool *sync.Pool AddrPool *sync.Pool
Key string Key string
LocalAddr net.Addr LocalAddr net.Addr
LocalIP net.IP
Logger logging.LeveledLogger Logger logging.LeveledLogger
} }

View File

@@ -1,51 +0,0 @@
//go:build !js && !windows
package ice
import (
"bytes"
"encoding/binary"
"errors"
"net"
"syscall"
"github.com/pion/logging"
)
var errUnknownOobData = errors.New("unknown oob data")
func setUDPSocketOptionsForLocalAddr(fd uintptr, logger logging.LeveledLogger) {
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_2292PKTINFO, 1); err != nil {
logger.Warnf("Failed to set sockopt IPV6_2292PKTINFO: %s", err)
}
if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1); err != nil {
logger.Warnf("Failed to set sockopt IP_PKTINFO: %s", err)
}
}
func getLocalAddrFromOob(oob []byte) (net.IP, error) {
var localHost net.IP
// get destination local addr from received packet
oobBuffer := bytes.NewBuffer(oob)
msg := syscall.Cmsghdr{}
err := binary.Read(oobBuffer, binary.LittleEndian, &msg)
if err == nil {
switch {
case msg.Level == syscall.IPPROTO_IP && msg.Type == syscall.IP_PKTINFO:
packetInfo := syscall.Inet4Pktinfo{}
if err = binary.Read(oobBuffer, binary.LittleEndian, &packetInfo); err == nil {
localHost = net.IP(packetInfo.Addr[:])
return localHost, nil
}
case msg.Level == syscall.IPPROTO_IPV6 && msg.Type == syscall.IPV6_2292PKTINFO:
packetInfo := syscall.Inet6Pktinfo{}
if err = binary.Read(oobBuffer, binary.LittleEndian, &packetInfo); err == nil {
localHost = net.IP(packetInfo.Addr[:])
return localHost, nil
}
default:
return localHost, errUnknownOobData
}
}
return localHost, err
}

View File

@@ -1,19 +0,0 @@
//go:build js || windows
package ice
import (
"errors"
"net"
"github.com/pion/logging"
)
var errUnsupported = errors.New("unsupported")
func setUDPSocketOptionsForLocalAddr(fd uintptr, logger logging.LeveledLogger) {
}
func getLocalAddrFromOob(oob []byte) (net.IP, error) {
return nil, errUnsupported
}