Single port handling via UDPMux

Allows for ICE to handle connections on a single UDP port
This commit is contained in:
David Zhao
2021-04-09 21:27:26 -07:00
parent 6e4403794a
commit 86d69d6ce5
12 changed files with 627 additions and 9 deletions

View File

@@ -62,6 +62,7 @@ Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contribu
* [Assad Obaid](https://github.com/assadobaid)
* [Antoine Baché](https://github.com/Antonito)
* [Will Forcey](https://github.com/wawesomeNOGUI)
* [David Zhao](https://github.com/davidzhao)
### License
MIT License - see [LICENSE](LICENSE) for full text

View File

@@ -123,6 +123,7 @@ type Agent struct {
net *vnet.Net
tcpMux TCPMux
udpMux UDPMux
interfaceFilter func(string) bool
@@ -314,6 +315,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
if a.tcpMux == nil {
a.tcpMux = newInvalidTCPMux()
}
a.udpMux = config.UDPMux
if a.net == nil {
a.net = vnet.NewNet(nil)
@@ -897,6 +899,9 @@ func (a *Agent) Close() error {
a.err.Store(ErrClosed)
a.tcpMux.RemoveConnByUfrag(a.localUfrag)
if a.udpMux != nil {
a.udpMux.RemoveConnByUfrag(a.localUfrag)
}
close(a.done)

View File

@@ -145,6 +145,11 @@ type AgentConfig struct {
// experimental and the API might change in the future.
TCPMux TCPMux
// UDPMux is used for multiplexing multiple incoming UDP connections on a single port
// when this is set, the agent ignores PortMin and PortMax configurations and will
// defer to UDPMux for incoming connections
UDPMux UDPMux
// Proxy Dialer is a dialer that should be implemented by the user based on golang.org/x/net/proxy
// dial interface in order to support corporate proxies
ProxyDialer proxy.Dialer

67
agent_udpmux_test.go Normal file
View File

@@ -0,0 +1,67 @@
// +build !js
package ice
import (
"testing"
"github.com/pion/logging"
"github.com/pion/transport/test"
"github.com/stretchr/testify/require"
)
// TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux
func TestMuxAgent(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
loggerFactory := logging.NewDefaultLoggerFactory()
udpMux := NewUDPMuxDefault(UDPMuxParams{
Logger: loggerFactory.NewLogger("ice"),
ReadBufferSize: 20,
})
muxPort := 7686
require.NoError(t, udpMux.Start(muxPort))
muxedA, err := NewAgent(&AgentConfig{
UDPMux: udpMux,
CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: supportedNetworkTypes(),
})
require.NoError(t, err)
a, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: supportedNetworkTypes(),
})
require.NoError(t, err)
conn, muxedConn := connect(a, muxedA)
pair := muxedA.getSelectedPair()
require.NotNil(t, pair)
require.Equal(t, muxPort, pair.Local.Port())
// send a packet to Mux
data := []byte("hello world")
_, err = conn.Write(data)
require.NoError(t, err)
buffer := make([]byte, 1024)
n, err := muxedConn.Read(buffer)
require.NoError(t, err)
require.Equal(t, data, buffer[:n])
// send a packet from Mux
_, err = muxedConn.Write(data)
require.NoError(t, err)
n, err = conn.Read(buffer)
require.NoError(t, err)
require.Equal(t, data, buffer[:n])
// close it down
require.NoError(t, conn.Close())
require.NoError(t, muxedConn.Close())
require.NoError(t, udpMux.Close())
}

View File

@@ -109,6 +109,9 @@ var (
// ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr.
ErrTCPRemoteAddrAlreadyExists = errors.New("conn with same remote addr already exists")
// ErrMuxNotStarted indicates the Mux has not been started prior to use
ErrMuxNotStarted = errors.New("mux must be started first")
errSendPacket = errors.New("failed to send packet")
errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate")
errParseComponent = errors.New("could not parse component")

View File

@@ -164,7 +164,6 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
switch network {
case tcp:
// Handle ICE TCP passive mode
a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag)
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag)
if err != nil {
@@ -178,11 +177,19 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
// is there a way to verify that the listen address is even
// accessible from the current interface.
case udp:
if a.udpMux != nil {
conn, err = a.udpMux.GetConnByUfrag(a.localUfrag)
if err != nil {
a.log.Warnf("could not get udp muxed connection: %v\n", err)
continue
}
} else {
conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
if err != nil {
a.log.Warnf("could not listen %s %s\n", network, ip)
continue
}
}
port = conn.LocalAddr().(*net.UDPAddr).Port
}

2
go.mod
View File

@@ -4,6 +4,7 @@ go 1.13
require (
github.com/google/uuid v1.2.0
github.com/kr/pretty v0.1.0 // indirect
github.com/pion/dtls/v2 v2.0.9
github.com/pion/logging v0.2.2
github.com/pion/mdns v0.0.5
@@ -13,4 +14,5 @@ require (
github.com/pion/turn/v2 v2.0.5
github.com/stretchr/testify v1.7.0
golang.org/x/net v0.0.0-20210331212208-0fccb6fa2b5c
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
)

8
go.sum
View File

@@ -2,6 +2,11 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs=
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pion/dtls/v2 v2.0.9 h1:7Ow+V++YSZQMYzggI0P9vLJz/hUFcffsfGMfT/Qy+u8=
github.com/pion/dtls/v2 v2.0.9/go.mod h1:O0Wr7si/Zj5/EBFlDzDd6UtVxx25CE1r7XM7BQKYQho=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
@@ -50,7 +55,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -23,9 +23,8 @@ type TCPMux interface {
RemoveConnByUfrag(ufrag string)
}
// invalidTCPMux is an implementation of TCPMux that always returns ErroTCPMuxNotInitialized.
type invalidTCPMux struct {
}
// invalidTCPMux is an implementation of TCPMux that always returns ErrTCPMuxNotInitialized.
type invalidTCPMux struct{}
func newInvalidTCPMux() *invalidTCPMux {
return &invalidTCPMux{}

251
udp_mux.go Normal file
View File

@@ -0,0 +1,251 @@
package ice
import (
"io"
"net"
"sync"
"github.com/pion/logging"
)
// UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface {
io.Closer
GetConnByUfrag(ufrag string) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
listenAddr *net.UDPAddr
udpConn *net.UDPConn
mappingChan chan connMap
closedChan chan struct{}
closeOnce sync.Once
// conns is a map of all udpMuxedConn indexed by ufrag
conns map[string]*udpMuxedConn
// buffer pool to recycle buffers for incoming packets
pool *sync.Pool
mu sync.Mutex
}
type connMap struct {
address string
conn *udpMuxedConn
}
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
ReadBufferSize int
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return &UDPMuxDefault{
params: params,
conns: make(map[string]*udpMuxedConn),
mappingChan: make(chan connMap, 10),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
return make([]byte, receiveMTU)
},
},
}
}
// Start starts the mux. Before the UDPMux is usable, it must be started
func (m *UDPMuxDefault) Start(port int) error {
if m.udpConn != nil {
return ErrMultipleStart
}
m.listenAddr = &net.UDPAddr{
Port: port,
}
uc, err := net.ListenUDP(udp, m.listenAddr)
if err != nil {
return err
}
m.udpConn = uc
go m.connWorker()
return nil
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.listenAddr
}
// GetConnByUfrag returns a PacketConn given the connection's ufrag.
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConnByUfrag(ufrag string) (net.PacketConn, error) {
if m.udpConn == nil {
return nil, ErrMuxNotStarted
}
m.mu.Lock()
defer m.mu.Unlock()
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
if c, ok := m.conns[ufrag]; ok {
return c, nil
}
c := m.createMuxedConn()
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
m.conns[ufrag] = c
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
// get addresses to remove
m.mu.Lock()
c := m.conns[ufrag]
delete(m.conns, ufrag)
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
if c == nil {
return
}
addresses := c.getAddresses()
for _, addr := range addresses {
m.mappingChan <- connMap{
address: addr,
conn: nil,
}
}
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
return true
default:
return false
}
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
defer m.mu.Unlock()
// close udp conn and prevent packets coming in
err = m.udpConn.Close()
for _, c := range m.conns {
_ = c.Close()
}
m.conns = make(map[string]*udpMuxedConn)
close(m.closedChan)
})
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, raddr net.Addr) (n int, err error) {
return m.udpConn.WriteTo(buf, raddr)
}
func (m *UDPMuxDefault) doneWithBuffer(buf []byte) {
//nolint
m.pool.Put(buf)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
m.mappingChan <- connMap{
address: addr,
conn: conn,
}
}
func (m *UDPMuxDefault) createMuxedConn() *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
ReadBuffer: m.params.ReadBufferSize,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
func (m *UDPMuxDefault) connWorker() {
// map of remote addresses -> udpMuxedConn
// used to look up incoming packets
remoteMap := make(map[string]*udpMuxedConn)
logger := m.params.Logger
defer func() {
_ = m.Close()
}()
for {
buffer := m.pool.Get().([]byte)
n, addr, err := m.udpConn.ReadFrom(buffer)
if err == io.EOF {
return
} else if err != nil {
logger.Errorf("could not read udp packet: %v", err)
return
}
// process any mapping changes
m.applyMappingChanges(remoteMap)
// look up forward destination
addrStr := addr.String()
c := remoteMap[addrStr]
if c == nil {
//nolint
m.pool.Put(buffer)
// ignore packets that we don't know where to route to
continue
}
err = c.writePacket(muxedPacket{
Data: buffer,
Size: n,
RAddr: addr,
})
if err != nil {
logger.Errorf("could not write packet: %v", err)
}
}
}
func (m *UDPMuxDefault) applyMappingChanges(remoteMap map[string]*udpMuxedConn) {
for {
select {
case cm := <-m.mappingChan:
// deregister previous addresses
existingConn := remoteMap[cm.address]
if existingConn != nil {
existingConn.removeAddress(cm.address)
}
remoteMap[cm.address] = cm.conn
default:
return
}
}
}

107
udp_mux_test.go Normal file
View File

@@ -0,0 +1,107 @@
// +build !js
package ice
import (
"net"
"sync"
"testing"
"time"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/test"
"github.com/stretchr/testify/require"
)
func TestUDPMux(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
loggerFactory := logging.NewDefaultLoggerFactory()
udpMux := NewUDPMuxDefault(UDPMuxParams{
Logger: loggerFactory.NewLogger("ice"),
ReadBufferSize: 20,
})
err := udpMux.Start(7686)
require.NoError(t, err)
defer func() {
_ = udpMux.Close()
}()
require.NotNil(t, udpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
require.Equal(t, ":7686", udpMux.LocalAddr().String())
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
testMuxConnection(t, udpMux, "ufrag1")
}()
wg.Add(1)
go func() {
defer wg.Done()
testMuxConnection(t, udpMux, "ufrag2")
}()
testMuxConnection(t, udpMux, "ufrag3")
wg.Wait()
require.NoError(t, udpMux.Close())
// can't create more connections
_, err = udpMux.GetConnByUfrag("failufrag")
require.Error(t, err)
}
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string) {
pktConn, err := udpMux.GetConnByUfrag(ufrag)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
}()
remoteConn, err := net.DialUDP(udp, nil, udpMux.LocalAddr().(*net.UDPAddr))
require.NoError(t, err, "error dialing test udp connection")
// initial messages are dropped
_, err = remoteConn.Write([]byte("dropped bytes"))
require.NoError(t, err)
// wait for packet to be consumed
time.Sleep(time.Millisecond)
// write out to establish connection
msg := stun.New()
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag"))
msg.Encode()
_, err = pktConn.WriteTo(msg.Raw, remoteConn.LocalAddr())
require.NoError(t, err)
// ensure received
buf := make([]byte, receiveMTU)
n, err := remoteConn.Read(buf)
require.NoError(t, err)
require.Equal(t, msg.Raw, buf[:n])
// write a bunch of packets from remote to ensure proper receipt
dataToSend := [][]byte{
[]byte("hello world"),
[]byte("test text"),
msg.Raw,
}
buffer := make([]byte, receiveMTU)
for _, data := range dataToSend {
_, err := remoteConn.Write(data)
require.NoError(t, err)
n, _, err := pktConn.ReadFrom(buffer)
require.NoError(t, err)
require.Equal(t, data, buffer[:n])
time.Sleep(10 * time.Millisecond)
}
}

165
udp_muxed_conn.go Normal file
View File

@@ -0,0 +1,165 @@
package ice
import (
"io"
"net"
"sync"
"time"
"github.com/pion/logging"
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
ReadBuffer int
LocalAddr net.Addr
Logger logging.LeveledLogger
}
type muxedPacket struct {
Data []byte
RAddr net.Addr
Size int
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
recvChan chan muxedPacket
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
}
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
recvChan: make(chan muxedPacket, params.ReadBuffer),
closedChan: make(chan struct{}),
}
return p
}
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
pkt, ok := <-c.recvChan
if !ok {
return 0, nil, io.ErrClosedPipe
}
if cap(b) < pkt.Size {
return 0, pkt.RAddr, io.ErrShortBuffer
}
copy(b, pkt.Data[:pkt.Size])
c.params.Mux.doneWithBuffer(pkt.Data)
return pkt.Size, pkt.RAddr, err
}
func (c *udpMuxedConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
}
// each time we write to a new address, we'll register it with the mux
addr := raddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
}
return c.params.Mux.writeTo(buf, raddr)
}
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
}
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
}
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
}
func (c *udpMuxedConn) Close() error {
c.closeOnce.Do(func() {
close(c.closedChan)
close(c.recvChan)
})
c.mu.Lock()
defer c.mu.Unlock()
c.addresses = nil
return nil
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
default:
return false
}
}
func (c *udpMuxedConn) getAddresses() []string {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
func (c *udpMuxedConn) addAddress(addr string) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
}
func (c *udpMuxedConn) removeAddress(addr string) {
newAddresses := make([]string, 0, len(c.addresses))
for _, a := range c.addresses {
if a != addr {
newAddresses = append(newAddresses, a)
}
}
c.mu.Lock()
c.addresses = newAddresses
c.mu.Unlock()
}
func (c *udpMuxedConn) containsAddress(addr string) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
}
}
return false
}
func (c *udpMuxedConn) writePacket(pkt muxedPacket) error {
select {
case c.recvChan <- pkt:
return nil
case <-c.closedChan:
return io.ErrClosedPipe
}
}