Add multi-port wrappers for UDPMux and TCPMux

These wrappers allow a caller to provide UDPMux and TCPMux instances
to the ICE agent that represent multiple open ports. This can be
desirable in what would otherwise be single-port deployments, as it
increases the chance that one of the fixed ports will not be blocked
by a users firewall.
This commit is contained in:
Kevin Caffrey
2022-09-04 17:21:20 -04:00
parent be69d2c2ae
commit 169ff6a7b4
8 changed files with 592 additions and 67 deletions

View File

@@ -139,4 +139,6 @@ var (
errSendSTUNPacket = errors.New("failed to send STUN packet")
errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr")
errNotImplemented = errors.New("not implemented yet")
errNoUDPMuxAvailable = errors.New("no UDP mux is available")
errNoTCPMuxAvailable = errors.New("no TCP mux is available")
)

View File

@@ -158,7 +158,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
for _, ip := range localIPs {
mappedIP := ip
if a.mDNSMode != MulticastDNSModeQueryAndGather && a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
if _mappedIP, err := a.extIPMapper.findExternalIP(ip.String()); err == nil {
if _mappedIP, innerErr := a.extIPMapper.findExternalIP(ip.String()); innerErr == nil {
mappedIP = _mappedIP
} else {
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", ip.String())
@@ -171,70 +171,93 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
}
for network := range networks {
var (
port int
type connAndPort struct {
conn net.PacketConn
err error
port int
}
var (
conns []connAndPort
tcpType TCPType
)
switch network {
case tcp:
// Handle ICE TCP passive mode
var muxConns []net.PacketConn
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil)
if err != nil {
if !errors.Is(err, ErrTCPMuxNotInitialized) {
a.log.Warnf("error getting all tcp conns by ufrag: %s %s %s", network, ip, a.localUfrag)
}
continue
}
} else {
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil)
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil)
if err != nil {
if !errors.Is(err, ErrTCPMuxNotInitialized) {
a.log.Warnf("error getting tcp conn by ufrag: %s %s %s", network, ip, a.localUfrag)
}
continue
}
muxConns = []net.PacketConn{conn}
}
// Extract the port for each PacketConn we got.
for _, conn := range muxConns {
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
port = tcpConn.Port
conns = append(conns, connAndPort{conn, tcpConn.Port})
} else {
a.log.Warnf("failed to get port of conn from TCPMux: %s %s %s", network, ip, a.localUfrag)
}
}
if len(conns) == 0 {
// Didn't succeed with any, try the next network.
continue
}
tcpType = TCPTypePassive
// is there a way to verify that the listen address is even
// accessible from the current interface.
case udp:
conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
if err != nil {
a.log.Warnf("could not listen %s %s", network, ip)
continue
}
if udpConn, ok := conn.LocalAddr().(*net.UDPAddr); ok {
port = udpConn.Port
conns = append(conns, connAndPort{conn, udpConn.Port})
} else {
a.log.Warnf("failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, ip, a.localUfrag)
continue
}
}
for _, connAndPort := range conns {
hostConfig := CandidateHostConfig{
Network: network,
Address: address,
Port: port,
Port: connAndPort.port,
Component: ComponentRTP,
TCPType: tcpType,
}
c, err := NewCandidateHost(&hostConfig)
if err != nil {
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create host candidate: %s %s %d: %v", network, mappedIP, port, err))
closeConnAndLog(connAndPort.conn, a.log, fmt.Sprintf("Failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err))
continue
}
if a.mDNSMode == MulticastDNSModeQueryAndGather {
if err = c.setIP(ip); err != nil {
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create host candidate: %s %s %d: %v", network, mappedIP, port, err))
closeConnAndLog(connAndPort.conn, a.log, fmt.Sprintf("Failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err))
continue
}
}
if err := a.addCandidate(ctx, c, conn); err != nil {
if err := a.addCandidate(ctx, c, connAndPort.conn); err != nil {
if closeErr := c.close(); closeErr != nil {
a.log.Warnf("Failed to close candidate: %v", closeErr)
}
@@ -243,8 +266,9 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
}
}
}
}
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit
if a.udpMux == nil {
return errUDPMuxDisabled
}
@@ -259,7 +283,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
for _, candidateIP := range localIPs {
if a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
if mappedIP, err := a.extIPMapper.findExternalIP(candidateIP.String()); err != nil {
if mappedIP, innerErr := a.extIPMapper.findExternalIP(candidateIP.String()); innerErr != nil {
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String())
continue
} else {
@@ -267,11 +291,25 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
}
}
var conns []net.PacketConn
if multi, ok := a.udpMux.(AllConnsGetter); ok {
conns, err = multi.GetAllConns(a.localUfrag, candidateIP.To4() == nil)
if err != nil {
return err
}
if len(conns) == 0 {
a.log.Warnf("Failed to get any connections from MultiUDPMux for candidate: %s", candidateIP)
continue
}
} else {
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil)
if err != nil {
return err
}
conns = []net.PacketConn{conn}
}
for _, conn := range conns {
udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create host mux candidate: %s failed to cast", candidateIP))
@@ -300,6 +338,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
continue
}
}
}
return nil
}
@@ -585,7 +624,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli
return
}
conn, connectErr := dtls.Dial(network, udpAddr, &dtls.Config{
conn, connectErr := dtls.Dial(network, udpAddr, &dtls.Config{ //nolint:contextcheck
ServerName: url.Host,
InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec
})

View File

@@ -488,6 +488,119 @@ func TestTURNProxyDialer(t *testing.T) {
assert.NoError(t, a.Close())
}
// Assert that candidates are given for each mux in a MultiUDPMux
func TestMultiUDPMuxUsage(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
var expectedPorts []int
var udpMuxInstances []UDPMux
for i := 0; i < 3; i++ {
port := randomPort(t)
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: port})
assert.NoError(t, err)
defer func() {
_ = conn.Close()
}()
expectedPorts = append(expectedPorts, port)
udpMuxInstances = append(udpMuxInstances, NewUDPMuxDefault(UDPMuxParams{
UDPConn: conn,
}))
}
a, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
CandidateTypes: []CandidateType{CandidateTypeHost},
UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...),
})
assert.NoError(t, err)
candidateCh := make(chan Candidate)
assert.NoError(t, a.OnCandidate(func(c Candidate) {
if c == nil {
close(candidateCh)
return
}
candidateCh <- c
}))
assert.NoError(t, a.GatherCandidates())
portFound := make(map[int]bool)
for c := range candidateCh {
portFound[c.Port()] = true
assert.True(t, c.NetworkType().IsUDP(), "All candidates should be UDP")
}
assert.Len(t, portFound, len(expectedPorts))
for _, port := range expectedPorts {
assert.True(t, portFound[port], "There should be a candidate for each UDP mux port")
}
assert.NoError(t, a.Close())
}
// Assert that candidates are given for each mux in a MultiTCPMux
func TestMultiTCPMuxUsage(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
var expectedPorts []int
var tcpMuxInstances []TCPMux
for i := 0; i < 3; i++ {
port := randomPort(t)
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
Port: port,
})
assert.NoError(t, err)
defer func() {
_ = listener.Close()
}()
expectedPorts = append(expectedPorts, port)
tcpMuxInstances = append(tcpMuxInstances, NewTCPMuxDefault(TCPMuxParams{
Listener: listener,
ReadBufferSize: 8,
}))
}
a, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
CandidateTypes: []CandidateType{CandidateTypeHost},
TCPMux: NewMultiTCPMuxDefault(tcpMuxInstances...),
})
assert.NoError(t, err)
candidateCh := make(chan Candidate)
assert.NoError(t, a.OnCandidate(func(c Candidate) {
if c == nil {
close(candidateCh)
return
}
candidateCh <- c
}))
assert.NoError(t, a.GatherCandidates())
portFound := make(map[int]bool)
for c := range candidateCh {
if c.NetworkType().IsTCP() {
portFound[c.Port()] = true
}
}
assert.Len(t, portFound, len(expectedPorts))
for _, port := range expectedPorts {
assert.True(t, portFound[port], "There should be a candidate for each TCP mux port")
}
assert.NoError(t, a.Close())
}
// Assert that UniversalUDPMux is used while gathering when configured in the Agent
func TestUniversalUDPMuxUsage(t *testing.T) {
report := test.CheckRoutines(t)

73
tcp_mux_multi.go Normal file
View File

@@ -0,0 +1,73 @@
// Package ice ...
//
//nolint:dupl
package ice
import "net"
// MultiTCPMuxDefault implements both TCPMux and AllConnsGetter,
// allowing users to pass multiple TCPMux instances to the ICE agent
// configuration.
type MultiTCPMuxDefault struct {
muxs []TCPMux
}
// NewMultiTCPMuxDefault creates an instance of MultiTCPMuxDefault that
// uses the provided TCPMux instances.
func NewMultiTCPMuxDefault(muxs ...TCPMux) *MultiTCPMuxDefault {
return &MultiTCPMuxDefault{
muxs: muxs,
}
}
// GetConnByUfrag returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found. This, unlike
// GetAllConns, will only return a single PacketConn from the first mux that was
// passed in to NewMultiTCPMuxDefault.
func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) {
// NOTE: We always use the first element here in order to maintain the
// behavior of using an existing connection if one exists.
if len(m.muxs) == 0 {
return nil, errNoTCPMuxAvailable
}
return m.muxs[0].GetConnByUfrag(ufrag, isIPv6)
}
// RemoveConnByUfrag stops and removes the muxed packet connection
// from all underlying TCPMux instances.
func (m *MultiTCPMuxDefault) RemoveConnByUfrag(ufrag string) {
for _, mux := range m.muxs {
mux.RemoveConnByUfrag(ufrag)
}
}
// GetAllConns returns a PacketConn for each underlying TCPMux
func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool) ([]net.PacketConn, error) {
if len(m.muxs) == 0 {
// Make sure that we either return at least one connection or an error.
return nil, errNoTCPMuxAvailable
}
var conns []net.PacketConn
for _, mux := range m.muxs {
conn, err := mux.GetConnByUfrag(ufrag, isIPv6)
if err != nil {
// For now, this implementation is all or none.
return nil, err
}
if conn != nil {
conns = append(conns, conn)
}
}
return conns, nil
}
// Close the multi mux, no further connections could be created
func (m *MultiTCPMuxDefault) Close() error {
var err error
for _, mux := range m.muxs {
if e := mux.Close(); e != nil {
err = e
}
}
return err
}

128
tcp_mux_multi_test.go Normal file
View File

@@ -0,0 +1,128 @@
//go:build !js
// +build !js
package ice
import (
"io"
"net"
"testing"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMultiTCPMux_Recv(t *testing.T) {
for name, buffersize := range map[string]int{
"no buffer": 0,
"buffered 4MB": 4 * 1024 * 1024,
} {
bufSize := buffersize
t.Run(name, func(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
loggerFactory := logging.NewDefaultLoggerFactory()
var muxInstances []TCPMux
for i := 0; i < 3; i++ {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 0,
})
require.NoError(t, err, "error starting listener")
defer func() {
_ = listener.Close()
}()
tcpMux := NewTCPMuxDefault(TCPMuxParams{
Listener: listener,
Logger: loggerFactory.NewLogger("ice"),
ReadBufferSize: 20,
WriteBufferSize: bufSize,
})
muxInstances = append(muxInstances, tcpMux)
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
}
multiMux := NewMultiTCPMuxDefault(muxInstances...)
defer func() {
_ = multiMux.Close()
}()
pktConns, err := multiMux.GetAllConns("myufrag", false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
for _, pktConn := range pktConns {
defer func() {
_ = pktConn.Close()
}()
conn, err := net.DialTCP("tcp", nil, pktConn.LocalAddr().(*net.TCPAddr))
require.NoError(t, err, "error dialing test tcp connection")
msg := stun.New()
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag"))
msg.Encode()
n, err := writeStreamingPacket(conn, msg.Raw)
require.NoError(t, err, "error writing tcp stun packet")
recv := make([]byte, n)
n2, raddr, err := pktConn.ReadFrom(recv)
require.NoError(t, err, "error receiving data")
assert.Equal(t, conn.LocalAddr(), raddr, "remote tcp address mismatch")
assert.Equal(t, n, n2, "received byte size mismatch")
assert.Equal(t, msg.Raw, recv, "received bytes mismatch")
// check echo response
n, err = pktConn.WriteTo(recv, conn.LocalAddr())
require.NoError(t, err, "error writing echo stun packet")
recvEcho := make([]byte, n)
n3, err := readStreamingPacket(conn, recvEcho)
require.NoError(t, err, "error receiving echo data")
assert.Equal(t, n2, n3, "received byte size mismatch")
assert.Equal(t, msg.Raw, recvEcho, "received bytes mismatch")
}
})
}
}
func TestMultiTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
loggerFactory := logging.NewDefaultLoggerFactory()
var tcpMuxInstances []TCPMux
for i := 0; i < 3; i++ {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 0,
})
require.NoError(t, err, "error starting listener")
defer func() {
_ = listener.Close()
}()
tcpMux := NewTCPMuxDefault(TCPMuxParams{
Listener: listener,
Logger: loggerFactory.NewLogger("ice"),
ReadBufferSize: 20,
})
tcpMuxInstances = append(tcpMuxInstances, tcpMux)
}
muxMulti := NewMultiTCPMuxDefault(tcpMuxInstances...)
_, err := muxMulti.GetAllConns("test", false)
require.NoError(t, err, "error getting conn by ufrag")
require.NoError(t, muxMulti.Close(), "error closing tcpMux")
conn, err := muxMulti.GetAllConns("test", false)
assert.Nil(t, conn, "should receive nil because mux is closed")
assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
}

81
udp_mux_multi.go Normal file
View File

@@ -0,0 +1,81 @@
// Package ice ...
//
//nolint:dupl
package ice
import "net"
// AllConnsGetter allows multiple fixed UDP or TCP ports to be used,
// each which is multiplexed like UDPMux. AllConnsGetter also acts as
// a UDPMux, in which case it will return a single connection for one
// of the ports.
type AllConnsGetter interface {
GetAllConns(ufrag string, isIPv6 bool) ([]net.PacketConn, error)
}
// MultiUDPMuxDefault implements both UDPMux and AllConnsGetter,
// allowing users to pass multiple UDPMux instances to the ICE agent
// configuration.
type MultiUDPMuxDefault struct {
muxs []UDPMux
}
// NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that
// uses the provided UDPMux instances.
func NewMultiUDPMuxDefault(muxs ...UDPMux) *MultiUDPMuxDefault {
return &MultiUDPMuxDefault{
muxs: muxs,
}
}
// GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found. This, unlike
// GetAllConns, will only return a single PacketConn from the first
// mux that was passed in to NewMultiUDPMuxDefault.
func (m *MultiUDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) {
// NOTE: We always use the first element here in order to maintain the
// behavior of using an existing connection if one exists.
if len(m.muxs) == 0 {
return nil, errNoUDPMuxAvailable
}
return m.muxs[0].GetConn(ufrag, isIPv6)
}
// RemoveConnByUfrag stops and removes the muxed packet connection
// from all underlying UDPMux instances.
func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) {
for _, mux := range m.muxs {
mux.RemoveConnByUfrag(ufrag)
}
}
// GetAllConns returns a PacketConn for each underlying UDPMux
func (m *MultiUDPMuxDefault) GetAllConns(ufrag string, isIPv6 bool) ([]net.PacketConn, error) {
if len(m.muxs) == 0 {
// Make sure that we either return at least one connection or an error.
return nil, errNoUDPMuxAvailable
}
var conns []net.PacketConn
for _, mux := range m.muxs {
conn, err := mux.GetConn(ufrag, isIPv6)
if err != nil {
// For now, this implementation is all or none.
return nil, err
}
if conn != nil {
conns = append(conns, conn)
}
}
return conns, nil
}
// Close the multi mux, no further connections could be created
func (m *MultiUDPMuxDefault) Close() error {
var err error
for _, mux := range m.muxs {
if e := mux.Close(); e != nil {
err = e
}
}
return err
}

85
udp_mux_multi_test.go Normal file
View File

@@ -0,0 +1,85 @@
//go:build !js
// +build !js
package ice
import (
"net"
"sync"
"testing"
"time"
"github.com/pion/transport/test"
"github.com/stretchr/testify/require"
)
func TestMultiUDPMux(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
conn1, err := net.ListenUDP(udp, &net.UDPAddr{})
require.NoError(t, err)
conn2, err := net.ListenUDP(udp, &net.UDPAddr{})
require.NoError(t, err)
udpMuxMulti := NewMultiUDPMuxDefault(
NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1}),
NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2}),
)
defer func() {
_ = udpMuxMulti.Close()
_ = conn1.Close()
_ = conn2.Close()
}()
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp)
}()
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", "udp4")
}()
// skip ipv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 {
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", "udp6")
}
wg.Wait()
require.NoError(t, udpMuxMulti.Close())
// can't create more connections
_, err = udpMuxMulti.GetConn("failufrag", false)
require.Error(t, err)
}
func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) {
pktConns, err := udpMuxMulti.GetAllConns(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
for _, c := range pktConns {
_ = c.Close()
}
}()
require.Len(t, pktConns, len(udpMuxMulti.muxs), "there should be a PacketConn for every mux")
// Try talking with each PacketConn
for _, pktConn := range pktConns {
remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{
Port: pktConn.LocalAddr().(*net.UDPAddr).Port,
})
require.NoError(t, err, "error dialing test udp connection")
testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
}
}

View File

@@ -122,8 +122,12 @@ func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, networ
})
require.NoError(t, err, "error dialing test udp connection")
testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
}
func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) {
// initial messages are dropped
_, err = remoteConn.Write([]byte("dropped bytes"))
_, err := remoteConn.Write([]byte("dropped bytes"))
require.NoError(t, err)
// wait for packet to be consumed
time.Sleep(time.Millisecond)